This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 085d36c  [MetaSchedule] Refactor testing workloads (#10497)
085d36c is described below

commit 085d36ca7f86fec31af78a2d7ec924266bf7b147
Author: Junru Shao <[email protected]>
AuthorDate: Sat Mar 5 17:41:50 2022 -0800

    [MetaSchedule] Refactor testing workloads (#10497)
---
 python/tvm/auto_scheduler/relay_integration.py     |  12 +-
 python/tvm/meta_schedule/integration.py            |  15 +-
 python/tvm/meta_schedule/testing/__init__.py       |   3 -
 python/tvm/meta_schedule/testing/byoc_trt.py       |  53 ---
 .../meta_schedule/testing/conv2d_winograd_cpu.py   |   2 +-
 .../meta_schedule/testing/custom_builder_runner.py | 140 +++++++
 python/tvm/meta_schedule/testing/relay_workload.py | 435 +++++++++++++++------
 python/tvm/relay/build_module.py                   |  31 +-
 python/tvm/topi/nn/conv2d.py                       |   8 +-
 src/meta_schedule/integration.cc                   |  22 +-
 src/meta_schedule/schedule_rule/winograd.cc        |   2 +-
 .../unittest/test_meta_schedule_byoc_tensorrt.py   | 187 +++------
 .../unittest/test_meta_schedule_integration.py     |  97 +++--
 tests/python/unittest/test_meta_schedule_runner.py |  20 +-
 .../unittest/test_meta_schedule_tune_relay.py      |  95 ++---
 15 files changed, 683 insertions(+), 439 deletions(-)

diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py
index e9bb68a..7ff1840 100644
--- a/python/tvm/auto_scheduler/relay_integration.py
+++ b/python/tvm/auto_scheduler/relay_integration.py
@@ -329,9 +329,9 @@ def auto_schedule_topi(func_name, outs):
     """
 
     # pylint: disable=import-outside-toplevel
-    from tvm.auto_scheduler.measure import (
+    from tvm.auto_scheduler.measure import (  # lazily import to avoid recursive dependency
         prepare_input_map,
-    )  # lazily import to avoid recursive dependency
+    )
 
     io_tensors, has_layout_free, has_complex_op = traverse_to_get_io_tensors(outs)
     if not io_tensors:  # The compute includes dynamic shapes which are not supported yet.
@@ -482,4 +482,10 @@ def is_auto_scheduler_enabled():
     enabled: bool
         Whether the auto-scheduler is enabled
     """
-    return PassContext.current().config.get("relay.backend.use_auto_scheduler", False)
+    return PassContext.current().config.get(
+        "relay.backend.use_auto_scheduler",
+        False,
+    ) or PassContext.current().config.get(
+        "relay.backend.use_meta_schedule",
+        False,
+    )
diff --git a/python/tvm/meta_schedule/integration.py b/python/tvm/meta_schedule/integration.py
index 178239e..26b0144 100644
--- a/python/tvm/meta_schedule/integration.py
+++ b/python/tvm/meta_schedule/integration.py
@@ -204,10 +204,8 @@ def extract_task_from_relay(
     params: Optional[Dict[str, NDArray]] = None,
     *,
     opt_level: int = 3,
-    pass_config: Dict[str, Any] = {
-        "relay.backend.use_meta_schedule": True,
-    },
-    disabled_pass: List[str] = [],
+    pass_config: Optional[Dict[str, Any]] = None,
+    disabled_pass: Optional[List[str]] = None,
 ) -> List[ExtractedTask]:
     """Extract tuning tasks from a relay program.
 
@@ -221,9 +219,9 @@ def extract_task_from_relay(
         The associated parameters of the program
     opt_level : int
         The optimization level of the compiler
-    pass_config : Dict[str, Any]
+    pass_config : Optional[Dict[str, Any]]
         The pass config of the compiler
-    disabled_pass : List[str]
+    disabled_pass : Optional[List[str]]
         The list of disabled passes of the compiler
 
     Returns
@@ -250,6 +248,11 @@ def extract_task_from_relay(
         thread.start()
         thread.join()
 
+    if disabled_pass is None:
+        disabled_pass = []
+    if pass_config is None:
+        pass_config = {"relay.backend.use_meta_schedule": True}
+
     env = TaskExtraction()
     if isinstance(mod, RelayFunc):
         mod = IRModule.from_expr(mod)
diff --git a/python/tvm/meta_schedule/testing/__init__.py b/python/tvm/meta_schedule/testing/__init__.py
index 85b48b3..5d6081f 100644
--- a/python/tvm/meta_schedule/testing/__init__.py
+++ b/python/tvm/meta_schedule/testing/__init__.py
@@ -15,6 +15,3 @@
 # specific language governing permissions and limitations
 # under the License.
 """Testing utilities in meta schedule"""
-from .byoc_trt import relay_build_with_tensorrt
-from .local_rpc import LocalRPC
-from .relay_workload import MODEL_TYPE, MODEL_TYPES, get_network, get_torch_model
diff --git a/python/tvm/meta_schedule/testing/byoc_trt.py b/python/tvm/meta_schedule/testing/byoc_trt.py
deleted file mode 100644
index d459518..0000000
--- a/python/tvm/meta_schedule/testing/byoc_trt.py
+++ /dev/null
@@ -1,53 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""TensorRT-MetaSchedule integration"""
-# pylint: disable=import-outside-toplevel
-
-from typing import List
-import tvm
-from tvm.runtime import Module
-from tvm.meta_schedule.builder import BuilderResult
-from tvm.target import Target
-
-
-def relay_build_with_tensorrt(
-    mod: Module,
-    target: Target,
-    params: dict,
-) -> List[BuilderResult]:
-    """Build a Relay IRModule with TensorRT BYOC
-    Parameters
-    ----------
-    mod : IRModule
-        The Relay IRModule to build.
-    target : Target
-        The target to build the module for.
-    params : Dict[str, NDArray]
-        The parameter dict to build the module with.
-    Returns
-    -------
-    mod : runtime.Module
-        The built module.
-    """
-    from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt
-
-    assert isinstance(target, Target)
-    mod, config = partition_for_tensorrt(mod, params)
-    with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
-        result = tvm.relay.build_module._build_module_no_factory(mod, "cuda", "llvm", params)
-    assert isinstance(result, Module)
-    return result
diff --git a/python/tvm/meta_schedule/testing/conv2d_winograd_cpu.py b/python/tvm/meta_schedule/testing/conv2d_winograd_cpu.py
index bfd5f45..261768c 100644
--- a/python/tvm/meta_schedule/testing/conv2d_winograd_cpu.py
+++ b/python/tvm/meta_schedule/testing/conv2d_winograd_cpu.py
@@ -79,7 +79,7 @@ def conv2d_winograd_cpu(
             eps_1, nu_1, p_1, ci_1, r_a, r_b = T.axis.remap(
                 "SSSSRR", [i0_4, i1_4, i2_3, i3_3, i4, i5]
             )
-            T.block_attr({"schedule_rule": "meta_schedule.winograd_data_pack.cpu"})
+            T.block_attr({"schedule_rule": "meta_schedule.winograd_data_pack.llvm"})
             T.reads(
                 [
                     data_pack[eps_1, nu_1, p_1, ci_1],
diff --git a/python/tvm/meta_schedule/testing/custom_builder_runner.py b/python/tvm/meta_schedule/testing/custom_builder_runner.py
new file mode 100644
index 0000000..87bad5a
--- /dev/null
+++ b/python/tvm/meta_schedule/testing/custom_builder_runner.py
@@ -0,0 +1,140 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Customized builder and runner methods"""
+# pylint: disable=import-outside-toplevel
+
+from typing import TYPE_CHECKING, Dict, List
+
+if TYPE_CHECKING:
+    from tvm.ir import IRModule
+    from tvm.meta_schedule.runner import EvaluatorConfig
+    from tvm.runtime import Device, Module, NDArray
+    from tvm.target import Target
+
+
+def build_relay(
+    mod: "IRModule",
+    target: "Target",
+    params: Dict[str, "NDArray"],
+) -> "Module":
+    """Build a Relay IRModule
+
+    Parameters
+    ----------
+    mod : IRModule
+        The Relay IRModule to build.
+    target : Target
+        The target to build the module for.
+    params : Dict[str, NDArray]
+        The parameter dict to build the module with.
+
+    Returns
+    -------
+    mod : runtime.Module
+        The built module.
+    """
+    from tvm.relay.build_module import _build_module_no_factory as relay_build
+    from tvm.runtime import Module
+
+    result = relay_build(mod, target=target, target_host=None, params=params)
+    assert isinstance(result, Module)
+    return result
+
+
+def build_relay_with_tensorrt(
+    mod: "IRModule",
+    target: "Target",
+    params: Dict[str, "NDArray"],
+) -> "Module":
+    """Build a Relay IRModule with TensorRT BYOC
+
+    Parameters
+    ----------
+    mod : IRModule
+        The Relay IRModule to build.
+
+    target : Target
+        The target to build the module for.
+
+    params : Dict[str, NDArray]
+        The parameter dict to build the module with.
+
+    Returns
+    -------
+    mod : runtime.Module
+        The built module.
+    """
+    from tvm.ir.transform import PassContext
+    from tvm.relay.build_module import _build_module_no_factory as relay_build
+    from tvm.relay.op.contrib import tensorrt
+    from tvm.runtime import Module
+
+    mod, config = tensorrt.partition_for_tensorrt(mod, params)
+    with PassContext(
+        opt_level=3,
+        config={"relay.ext.tensorrt.options": config},
+    ):
+        result = relay_build(mod, target=target, target_host=None, params=params)
+    assert isinstance(result, Module)
+    return result
+
+
+def run_with_graph_executor(
+    rt_mod: "Module",
+    device: "Device",
+    evaluator_config: "EvaluatorConfig",
+    repeated_args: List["NDArray"],
+) -> List[float]:
+    """Run a Relay module with GraphExecutor
+
+    Parameters
+    ----------
+    rt_mod : Module
+        The Relay module to run.
+    device : Device
+        The device to run the module on.
+    evaluator_config : EvaluatorConfig
+        The evaluator configuration to run the module with.
+    repeated_args : List[NDArray]
+        The list of repeated arguments to run the module with.
+
+    Returns
+    -------
+    results : List[float]
+        The list of results.
+    """
+    import itertools
+
+    from tvm.contrib.graph_executor import GraphModule
+
+    graph_mod = GraphModule(rt_mod["default"](device))
+    evaluator = graph_mod.module.time_evaluator(
+        func_name="run",
+        dev=device,
+        number=evaluator_config.number,
+        repeat=evaluator_config.repeat,
+        min_repeat_ms=evaluator_config.min_repeat_ms,
+        f_preproc="cache_flush_cpu_non_first_arg"
+        if evaluator_config.enable_cpu_cache_flush
+        else "",
+    )
+    repeated_costs = []
+    for args in repeated_args:
+        profile_result = evaluator(*args)
+        repeated_costs.append(profile_result.results)
+    costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)]
+    return costs
diff --git a/python/tvm/meta_schedule/testing/relay_workload.py b/python/tvm/meta_schedule/testing/relay_workload.py
index 2f1ffdd..29cc70a 100644
--- a/python/tvm/meta_schedule/testing/relay_workload.py
+++ b/python/tvm/meta_schedule/testing/relay_workload.py
@@ -15,154 +15,333 @@
 # specific language governing permissions and limitations
 # under the License.
 """Workloads in Relay IR"""
-from enum import Enum
-from typing import Dict, Tuple
+# pylint: disable=import-outside-toplevel
+import multiprocessing
+import os
+import pickle
+from typing import Any, Dict, List, Optional, Tuple
 
-import tvm.relay.testing  # pylint: disable=unused-import
+import tvm
+import tvm.relay.testing
 from tvm import relay
 from tvm.ir import IRModule
-from tvm.runtime import NDArray
-
-# Model types supported in Torchvision
-class MODEL_TYPE(Enum):  # pylint: disable=invalid-name
-    IMAGE_CLASSIFICATION = (1,)
-    VIDEO_CLASSIFICATION = (2,)
-    SEGMENTATION = (3,)
-    OBJECT_DETECTION = (4,)
-    TEXT_CLASSIFICATION = (5,)
-
-
-# Specify the type of each model
-MODEL_TYPES = {
-    "resnet18": MODEL_TYPE.IMAGE_CLASSIFICATION,
-    "mobilenet_v2": MODEL_TYPE.IMAGE_CLASSIFICATION,
-    "bert_base": MODEL_TYPE.TEXT_CLASSIFICATION,
-}
-
-
-def get_torch_model(
-    model_name: str,
-    input_shape: Tuple[int, ...],
-    output_shape: Tuple[int, int],  # pylint: disable=unused-argument
-    dtype: str = "float32",
-) -> Tuple[IRModule, Dict[str, NDArray]]:
-    """Load model from torch model zoo
-    Parameters
-    ----------
-    model_name : str
-        The name of the model to load
-    input_shape: Tuple[int, ...]
-        Tuple for input shape
-    output_shape: Tuple[int, int]
-        Tuple for output shape
-    dtype: str
-        Tensor data type
-    """
+from tvm.meta_schedule.integration import ExtractedTask, extract_task_from_relay
+from tvm.runtime import NDArray, load_param_dict, save_param_dict
+from tvm.target import Target
 
-    assert dtype == "float32"
 
-    import torch  # type: ignore # pylint: disable=import-error,import-outside-toplevel
-    from torchvision import models  # type: ignore # pylint: disable=import-error,import-outside-toplevel
-    import transformers  # type: ignore # pylint: disable=import-error,import-outside-toplevel
-    import os  # type: ignore # pylint: disable=import-error,import-outside-toplevel
+def _get_network(
+    args: Tuple[str, List[int]]
+) -> Tuple[IRModule, bytearray, Tuple[str, List[int], str]]:
+    name: str
+    input_shape: List[int]
+    name, input_shape = args
 
-    def do_trace(model, inp):
-        model.eval()
-        model_trace = torch.jit.trace(model, inp)
-        model_trace.eval()
-        return model_trace
+    mod: IRModule
+
+    if name in [
+        "resnet_18",
+        "resnet_50",
+        "wide_resnet_50",
+        "resnext_50",
+        "mobilenet_v2",
+        "mobilenet_v3",
+        "inception_v3",
+        "densenet_121",
+        "resnet3d_18",
+        "vgg_16",
+    ]:
+        import torch  # type: ignore
+        from torchvision import models  # type: ignore
+
+        if name in ["resnet_18", "resnet_50"]:
+            model = getattr(models, name.replace("_", ""))(pretrained=False)
+        elif name == "wide_resnet_50":
+            model = getattr(models, "wide_resnet50_2")(pretrained=False)
+        elif name == "resnext_50":
+            model = getattr(models, "resnext50_32x4d")(pretrained=False)
+        elif name == "mobilenet_v2":
+            model = getattr(models, name)(pretrained=False)
+        elif name == "mobilenet_v3":
+            model = getattr(models, name + "_large")(pretrained=False)
+        elif name == "inception_v3":
+            model = getattr(models, name)(pretrained=False, aux_logits=False)
+        elif name == "densenet_121":
+            model = getattr(models, name.replace("_", ""))(pretrained=False)
+        elif name == "resnet3d_18":
+            model = models.video.r3d_18(pretrained=False)
+        elif name == "vgg_16":
+            model = getattr(models, name.replace("_", ""))(pretrained=False)
 
-    # Load model from torchvision
-    if MODEL_TYPES[model_name] == MODEL_TYPE.TEXT_CLASSIFICATION:
+        dtype = "float32"
+        input_data = torch.randn(input_shape).type(  # pylint: disable=no-member
+            {
+                "float32": torch.float32,  # pylint: disable=no-member
+            }[dtype]
+        )
+        scripted_model = torch.jit.trace(model, input_data).eval()
+        input_name = "input0"
+        shape_list = [(input_name, input_shape)]
+        mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
+        with tvm.transform.PassContext(opt_level=3):
+            mod = tvm.transform.Sequential(
+                [
+                    relay.transform.RemoveUnusedFunctions(),
+                    relay.transform.ConvertLayout(
+                        {
+                            "nn.conv2d": ["NHWC", "default"],
+                            "nn.conv3d": ["NDHWC", "default"],
+                            "nn.max_pool2d": ["NHWC", "default"],
+                            "nn.avg_pool2d": ["NHWC", "default"],
+                        }
+                    ),
+                ]
+            )(mod)
+        inputs = (input_name, input_shape, dtype)
+    elif name in ["bert_tiny", "bert_base", "bert_medium", "bert_large"]:
         os.environ["TOKENIZERS_PARALLELISM"] = "false"
-        model = transformers.BertModel(
-            transformers.BertConfig(
+        # pip3 install transformers==3.5 torch==1.7
+        import torch  # type: ignore
+        import transformers  # type: ignore
+
+        config_dict = {
+            "bert_tiny": transformers.BertConfig(
+                num_hidden_layers=6,
+                hidden_size=512,
+                intermediate_size=2048,
+                num_attention_heads=8,
+                return_dict=False,
+            ),
+            "bert_base": transformers.BertConfig(
                 num_hidden_layers=12,
                 hidden_size=768,
                 intermediate_size=3072,
                 num_attention_heads=12,
                 return_dict=False,
-            )
-        )
+            ),
+            "bert_medium": transformers.BertConfig(
+                num_hidden_layers=12,
+                hidden_size=1024,
+                intermediate_size=4096,
+                num_attention_heads=16,
+                return_dict=False,
+            ),
+            "bert_large": transformers.BertConfig(
+                num_hidden_layers=24,
+                hidden_size=1024,
+                intermediate_size=4096,
+                num_attention_heads=16,
+                return_dict=False,
+            ),
+        }
+        configuration = config_dict[name]
+        model = transformers.BertModel(configuration)
+        input_name = "input_ids"
+        input_dtype = "int64"
+        a = torch.randint(10000, input_shape)  # pylint: disable=no-member
         model.eval()
-        input_data = torch.randint(10000, input_shape)
-        shape_list = [("input_ids", input_shape)]
-        scripted_model = torch.jit.trace(model, [input_data], strict=False)
-    elif MODEL_TYPES[model_name] == MODEL_TYPE.IMAGE_CLASSIFICATION:
-        model = getattr(models, model_name)()
-        # Setup input
-        input_data = torch.randn(input_shape).type(torch.float32)
-        shape_list = [("input0", input_shape)]
-        # Get trace. Depending on the model type, wrapper may be necessary.
-        scripted_model = do_trace(model, input_data)
+        scripted_model = torch.jit.trace(model, [a], strict=False)
+        input_name = "input_ids"
+        shape_list = [(input_name, input_shape)]
+        mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
+        mod = relay.transform.FastMath()(mod)
+        mod = relay.transform.CombineParallelBatchMatmul()(mod)
+        inputs = (input_name, input_shape, input_dtype)
+    elif name == "dcgan":
+        output_shape = input_shape
+        batch_size = output_shape[0]
+        oshape = output_shape[1:]
+        mod, params = relay.testing.dcgan.get_workload(
+            batch_size=batch_size,
+            oshape=oshape,
+            layout="NHWC",
+        )
+        inputs = ("data", [100], "float32")
     else:
-        raise ValueError("Unsupported model in Torch model zoo.")
+        raise ValueError("Invalid name: " + name)
+
+    params_bytearray: bytearray = save_param_dict(params)
+    return mod, params_bytearray, inputs
+
+
+def _load_cache(cache_dir: Optional[str], filename: str) -> Optional[List[Any]]:
+    if cache_dir is None:
+        return None
+    path = os.path.join(os.path.expanduser(cache_dir), filename)
+    if not os.path.exists(path):
+        return None
+    print(f"Load from cache: {path}")
+    with open(path, "rb") as i_f:
+        return pickle.load(i_f)
 
-    # Convert torch model to relay module
-    mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
-    return mod, params
+
+def _save_cache(cache_dir: Optional[str], filename: str, objects: List[Any]) -> None:
+    if cache_dir is None:
+        return
+    path = os.path.join(os.path.expanduser(cache_dir), filename)
+    with open(path, "wb") as o_f:
+        pickle.dump(objects, o_f)
 
 
 def get_network(
     name: str,
-    batch_size: int,
-    layout: str = "NHWC",
-    dtype: str = "float32",
-) -> Tuple[IRModule, Dict[str, NDArray], Tuple[int, int, int, int], Tuple[int, int]]:
-    """Get the symbol definition and random weight of a network"""
-    # meta-schedule prefers NHWC layout
-    if layout == "NHWC":
-        image_shape = (224, 224, 3)
-    elif layout == "NCHW":
-        image_shape = (3, 224, 224)
-    else:
-        raise ValueError("Invalid layout: " + layout)
+    input_shape: List[int],
+    *,
+    cache_dir: Optional[str] = None,
+) -> Tuple[IRModule, Dict[str, NDArray], Tuple[str, List[int], str]]:
+    """Get the symbol definition and random weight of a network
 
-    input_shape: Tuple[int, int, int, int] = (batch_size,) + image_shape
-    output_shape: Tuple[int, int] = (batch_size, 1000)
+    Parameters
+    ----------
+    name : str
+        The name of the network.
+    input_shape : List[int]
+        The shape of the input tensor.
+    cache_dir : Optional[str], optional
+        The directory to cache the generated network.
+        If not specified, the cache will be disabled.
 
-    if name.startswith("resnet-"):
-        n_layer = int(name.split("-")[1])
-        mod, params = relay.testing.resnet.get_workload(
-            num_layers=n_layer,
-            batch_size=batch_size,
-            layout=layout,
-            dtype=dtype,
-            image_shape=image_shape,
-        )
-    elif name.startswith("resnet3d-"):
-        n_layer = int(name.split("-")[1])
-        mod, params = relay.testing.resnet.get_workload(
-            num_layers=n_layer,
-            batch_size=batch_size,
-            layout=layout,
-            dtype=dtype,
-            image_shape=image_shape,
-        )
-    elif name == "mobilenet":
-        mod, params = relay.testing.mobilenet.get_workload(
-            batch_size=batch_size, layout=layout, dtype=dtype, image_shape=image_shape
-        )
-    elif name == "squeezenet_v1.1":
-        assert layout == "NCHW", "squeezenet_v1.1 only supports NCHW layout"
-        mod, params = relay.testing.squeezenet.get_workload(
-            version="1.1",
-            batch_size=batch_size,
-            dtype=dtype,
-            image_shape=image_shape,
-        )
-    elif name == "inception_v3":
-        input_shape = (batch_size, 3, 299, 299) if layout == "NCHW" else (batch_size, 299, 299, 3)
-        mod, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype)
-    elif name == "mxnet":
-        from mxnet.gluon.model_zoo.vision import get_model  # type: ignore  # pylint: disable=import-outside-toplevel
-
-        assert layout == "NCHW"
-        block = get_model("resnet50_v1", pretrained=True)
-        mod, params = relay.frontend.from_mxnet(block, shape={"data": input_shape}, dtype=dtype)
-        net = mod["main"]
-        net = relay.Function(
-            net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs
+    Returns
+    -------
+    mod : IRModule
+        The IRModule representing the network.
+    params : Dict[str, NDArray]
+        The parameters of the networks.
+    inputs : Tuple[str, List[int], str]
+        The name, shape and dtype of the input tensor.
+    """
+
+    mod: IRModule
+    params: Dict[str, NDArray]
+    inputs: Tuple[str, List[int], str]
+    params_bytearray: bytearray
+
+    filename = f'relay-{name}-{",".join(str(i) for i in input_shape)}.json'
+    cached = _load_cache(cache_dir, filename)
+    if cached is None:
+        with multiprocessing.Pool(processes=1) as pool:
+            result = pool.map(_get_network, [(name, input_shape)])
+        ((mod, params_bytearray, inputs),) = result
+        cached = [mod, params_bytearray, inputs]
+        _save_cache(cache_dir, filename, cached)
+    mod, params_bytearray, inputs = cached
+    params = load_param_dict(params_bytearray)
+    return mod, params, inputs
+
+
+def extract_from_relay(
+    mod: IRModule,
+    target: Target,
+    params: Optional[Dict[str, NDArray]],
+    name: str,
+    input_shape: List[int],
+    *,
+    cache_dir: Optional[str] = None,
+    opt_level: int = 3,
+    pass_config: Optional[Dict[str, Any]] = None,
+    disabled_pass: Optional[List[str]] = None,
+) -> List[ExtractedTask]:
+    """Extract the tasks from a network.
+
+    Parameters
+    ----------
+    mod : IRModule
+        The IRModule representing the network.
+    target : Target
+        The target that the network will be deployed to.
+    params : Optional[Dict[str, NDArray]]
+        The parameters of the networks.
+    name : str
+        The name of the network.
+    input_shape : List[int]
+        The shape of the input tensor.
+    cache_dir : Optional[str]
+        The directory to cache the generated network.
+        If not specified, the cache will be disabled.
+    opt_level : int
+        The optimization level of the compiler.
+    pass_config : Optional[Dict[str, Any]]
+        The pass config of the compiler.
+    disabled_pass : Optional[List[str]]
+        The disabled pass of the compiler.
+
+    Returns
+    -------
+    extracted_tasks : List[ExtractedTask]
+        The extracted tasks.
+    """
+    filename = f'tasks-{target.kind.name}-{name}-{",".join(str(i) for i in input_shape)}.json'
+    extracted_tasks = _load_cache(cache_dir, filename)
+    if extracted_tasks is None:
+        extracted_tasks = extract_task_from_relay(
+            mod=mod,
+            target=target,
+            params=params,
+            opt_level=opt_level,
+            pass_config=pass_config,
+            disabled_pass=disabled_pass,
         )
-        mod = IRModule.from_expr(net)
-    return mod, params, input_shape, output_shape
+        extracted_tasks = list(extracted_tasks)
+        _save_cache(cache_dir, filename, extracted_tasks)
+    return extracted_tasks
+
+
+def _build_dataset() -> List[Tuple[str, List[int]]]:
+    network_keys = []
+    for name in [
+        "resnet_18",
+        "resnet_50",
+        "mobilenet_v2",
+        "mobilenet_v3",
+        "wide_resnet_50",
+        "resnext_50",
+        "densenet_121",
+        "vgg_16",
+    ]:
+        for batch_size in [1, 4, 8]:
+            for image_size in [224, 240, 256]:
+                network_keys.append((name, [batch_size, 3, image_size, image_size]))
+    # inception-v3
+    for name in ["inception_v3"]:
+        for batch_size in [1, 2, 4]:
+            for image_size in [299]:
+                network_keys.append((name, [batch_size, 3, image_size, image_size]))
+    # resnet3d
+    for name in ["resnet3d_18"]:
+        for batch_size in [1, 2, 4]:
+            for image_size in [112, 128, 144]:
+                network_keys.append((name, [batch_size, 3, image_size, image_size, 16]))
+    # bert
+    for name in ["bert_tiny", "bert_base", "bert_medium", "bert_large"]:
+        for batch_size in [1, 2, 4]:
+            for seq_length in [64, 128, 256]:
+                network_keys.append((name, [batch_size, seq_length]))
+    # dcgan
+    for name in ["dcgan"]:
+        for batch_size in [1, 4, 8]:
+            for image_size in [64]:
+                network_keys.append((name, [batch_size, 3, image_size, image_size]))
+
+    return network_keys
+
+
+SUPPORTED = [
+    # TorchVision
+    "resnet_18",
+    "resnet_50",
+    "mobilenet_v2",
+    "mobilenet_v3",
+    "wide_resnet_50",
+    "resnext_50",
+    "resnet3d_18",
+    "inception_v3",
+    "densenet_121",
+    "vgg_16",
+    # Transformer
+    "bert_tiny",
+    "bert_base",
+    "bert_medium",
+    "bert_large",
+    # Relay testing
+    "dcgan",
+]
diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py
index 7872091..97f7adc 100644
--- a/python/tvm/relay/build_module.py
+++ b/python/tvm/relay/build_module.py
@@ -19,28 +19,31 @@ Construct the necessary state for the TVM graph executor
 from a Relay expression.
 """
 import warnings
-import numpy as np
 
+import numpy as np
 from tvm.ir import IRModule
-
 from tvm.ir.transform import PassContext
-from tvm.tir import expr as tvm_expr
 from tvm.target import Target
-from .. import nd as _nd, autotvm, register_func
+from tvm.tir import expr as tvm_expr
+
+from .. import autotvm
+from .. import nd as _nd
+from .. import register_func
+from ..contrib import graph_executor as _graph_executor
+from ..contrib import utils as contrib_utils
 from ..runtime import load_module
 from ..runtime.executor import aot_executor as _aot_executor
 from ..target import Target
-from ..contrib import graph_executor as _graph_executor
-from ..contrib import utils as contrib_utils
 from . import _build_module
-from . import ty as _ty
 from . import expr as _expr
 from . import function as _function
-from .transform import InferType
-from .backend.utils import mangle_module_name
-from .backend import executor_factory as _executor_factory, Executor, Runtime
+from . import ty as _ty
+from .backend import Executor, Runtime
+from .backend import executor_factory as _executor_factory
 from .backend import interpreter as _interpreter
+from .backend.utils import mangle_module_name
 from .backend.vm import VMExecutor
+from .transform import InferType
 
 
 def build_target_by_device_type_map(target):
@@ -287,13 +290,17 @@ def _module_export(module, file_name):  # fcompile, addons, kwargs?
 
 
 @register_func("tvm.relay.build")
+def _build_module_no_factory_impl(mod, target, target_host, params, mod_name):
+    target, target_host = Target.check_and_update_host_consist(target, target_host)
+    return build(mod, target, params=params, mod_name=mod_name).module
+
+
 def _build_module_no_factory(mod, target=None, target_host=None, params=None, mod_name="default"):
     """A wrapper around build which discards the Python GraphFactoryRuntime.
     This wrapper is suitable to be used from other programming languages as
     the runtime::Module can be freely passed between language boundaries.
     """
-    target, target_host = Target.check_and_update_host_consist(target, target_host)
-    return build(mod, target, params=params, mod_name=mod_name).module
+    return _build_module_no_factory_impl(mod, target, target_host, params, mod_name)
 
 
 def _reconstruct_from_deprecated_options(deprecated_params_target):
diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py
index b1230c0..dbff3f6 100644
--- a/python/tvm/topi/nn/conv2d.py
+++ b/python/tvm/topi/nn/conv2d.py
@@ -1052,6 +1052,12 @@ def _conv2d_winograd_nhwc_impl(
     )
 
     # transform data
+    target = tvm.target.Target.current(allow_none=True)
+    if target is not None:
+        target_kind = "meta_schedule.winograd_data_pack." + target.kind.name
+    else:
+        target_kind = "None"
+
     r_a = te.reduce_axis((0, alpha), "r_a")
     r_b = te.reduce_axis((0, alpha), "r_b")
     data_pack = te.compute(
@@ -1062,7 +1068,7 @@ def _conv2d_winograd_nhwc_impl(
         name="data_pack",
         attrs={
             "auto_scheduler_simplify_const_tensor_indices": ["eps", "nu", "r_a", "r_b"],
-            "schedule_rule": "meta_schedule.winograd_data_pack.cpu",
+            "schedule_rule": target_kind,
         },
         # the attrs are necessary hints for the auto-scheduler
     )
diff --git a/src/meta_schedule/integration.cc b/src/meta_schedule/integration.cc
index 1ecb537..1ebec19 100644
--- a/src/meta_schedule/integration.cc
+++ b/src/meta_schedule/integration.cc
@@ -28,17 +28,24 @@ namespace meta_schedule {
 /**************** Utility functions ****************/
 
 template <class FunctionType>
-bool HasOnlyOneFunction(const IRModule& mod) {
+Optional<FunctionType> GetOnlyOneFunction(const IRModule& mod) {
   if (mod->functions.size() != 1) {
-    return false;
+    return NullOpt;
   }
   for (const auto& kv : mod->functions) {
     const BaseFunc& func = kv.second;
     if (!func->IsInstance<typename FunctionType::ContainerType>()) {
-      return false;
+      return NullOpt;
+    } else {
+      return Downcast<FunctionType>(func);
     }
   }
-  return true;
+  return NullOpt;
+}
+
+template <class FunctionType>
+bool HasOnlyOneFunction(const IRModule& mod) {
+  return GetOnlyOneFunction<FunctionType>(mod).defined();
 }
 
 /**************** ExtractedTask ****************/
@@ -129,14 +136,17 @@ Optional<ObjectRef> ApplyHistoryBestNode::Query(runtime::String task_name, IRMod
   if (database->HasWorkload(prim_mod)) {
     Array<TuningRecord> records = database->GetTopK(database->CommitWorkload(prim_mod), 1);
     if (records.size() == 1) {
-      LOG(INFO) << "Applied history best for " << task_name << ".";
+      LOG(INFO) << "Applied history best for: " << task_name;
       tir::Schedule sch =
           tir::Schedule::Traced(records[0]->workload->mod, /*seed=*/-1, /*debug_mask=*/0,
                                 /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone);
       records[0]->trace->ApplyToSchedule(sch, false);
-      return sch->mod();
+      tir::PrimFunc func = GetOnlyOneFunction<tir::PrimFunc>(sch->mod()).value();
+      LOG(INFO) << "\n" << tir::AsTVMScript(func);
+      return func;
     }
   }
+  LOG(WARNING) << "Cannot find workload: " << task_name << "\n" << tir::AsTVMScript(prim_mod);
   return NullOpt;
 }
 
diff --git a/src/meta_schedule/schedule_rule/winograd.cc b/src/meta_schedule/schedule_rule/winograd.cc
index 44db6f2..d8aab3a 100644
--- a/src/meta_schedule/schedule_rule/winograd.cc
+++ b/src/meta_schedule/schedule_rule/winograd.cc
@@ -69,7 +69,7 @@ TVM_REGISTER_GLOBAL("meta_schedule.winograd_inverse")
       return {sch};
     });
 
-TVM_REGISTER_GLOBAL("meta_schedule.winograd_data_pack.cpu")
+TVM_REGISTER_GLOBAL("meta_schedule.winograd_data_pack.llvm")
     .set_body_typed([](Schedule sch, BlockRV data_pack) -> Array<Schedule> {
       BlockRV input_tile = GetOnlyProducer(sch, data_pack);
       BlockRV data_pad = GetOnlyProducer(sch, input_tile);
diff --git a/tests/python/unittest/test_meta_schedule_byoc_tensorrt.py b/tests/python/unittest/test_meta_schedule_byoc_tensorrt.py
index 3b4164c..91e2c41 100644
--- a/tests/python/unittest/test_meta_schedule_byoc_tensorrt.py
+++ b/tests/python/unittest/test_meta_schedule_byoc_tensorrt.py
@@ -15,37 +15,37 @@
 # specific language governing permissions and limitations
 # under the License.
 """ Test Meta Schedule Builder """
+# pylint: disable=missing-docstring
+
 import sys
+from typing import List
+
 import pytest
-import itertools
 import tvm
 from tvm import relay
+from tvm.meta_schedule.arg_info import TensorInfo
+from tvm.meta_schedule.builder import BuilderInput, LocalBuilder
+from tvm.meta_schedule.runner import EvaluatorConfig, LocalRunner, RunnerInput
+from tvm.meta_schedule.testing.custom_builder_runner import (
+    build_relay,
+    build_relay_with_tensorrt,
+    run_with_graph_executor,
+)
+from tvm.meta_schedule.testing.relay_workload import get_network
 from tvm.relay import testing
 from tvm.relay.op.contrib import tensorrt
-import numpy as np
-from typing import List
-from tvm._ffi import register_func
 from tvm.target import Target
-from tvm.runtime import Module
-from tvm.meta_schedule.arg_info import TensorInfo
-from tvm.meta_schedule.builder import BuilderInput, LocalBuilder, BuilderResult
-from tvm.meta_schedule.runner import (
-    EvaluatorConfig,
-    LocalRunner,
-    RunnerInput,
-)
-
 from tvm.tir import FloatImm
-from tvm.meta_schedule.testing import get_network
 
 has_tensorrt_codegen = pytest.mark.skipif(
-    not tvm.get_global_func("relay.ext.tensorrt", True), reason="TensorRT codegen not available"
+    not tvm.get_global_func("relay.ext.tensorrt", True),
+    reason="TensorRT codegen not available",
 )
 has_tensorrt_runtime = pytest.mark.skipif(
-    not tensorrt.is_tensorrt_runtime_enabled(), reason="TensorRT runtime not available"
+    not tensorrt.is_tensorrt_runtime_enabled(),
+    reason="TensorRT runtime not available",
 )
 
-
 # conv2d+relu network
 def get_conv2d_relu(
     data_shape,
@@ -83,105 +83,52 @@ def get_conv2d_relu(
 
 
 def verify_meta_schedule_with_tensorrt(
-    mod, params, data_shape, use_meta_sched: bool = True, use_trt: bool = True, mode: str = "vm"
+    mod,
+    params,
+    data_shape,
+    use_trt: bool = True,
 ):
-    if use_meta_sched:
-        # With meta_schedule
-        dev = "cuda"
-
-        # Build
-        if use_trt:
-            from tvm.meta_schedule.testing import relay_build_with_tensorrt
-
-            builder = LocalBuilder(f_build=relay_build_with_tensorrt)
-        else:
-
-            def relay_build_without_tensorrt(
-                mod: Module,
-                target: Target,
-                params: dict,
-            ) -> List[BuilderResult]:
-                return tvm.relay.build_module._build_module_no_factory(mod, "cuda", "llvm", params)
-
-            builder = LocalBuilder(f_build=relay_build_without_tensorrt)
-
-        builder_input = BuilderInput(mod, Target(dev, host="llvm"), params)
-
-        (builder_result,) = builder.build([builder_input])
-        assert builder_result.error_msg is None
-        assert builder_result.artifact_path is not None
-
-        # Run
-        evaluator_config = EvaluatorConfig(
+    # Build
+    builder = LocalBuilder(
+        f_build=build_relay_with_tensorrt if use_trt else build_relay,
+        timeout_sec=1000,
+    )
+    builder_input = BuilderInput(mod, Target("cuda"), params)
+    builder_result = builder.build([builder_input])[0]
+    assert builder_result.error_msg is None, builder_result.error_msg
+    assert builder_result.artifact_path is not None
+
+    # Run
+    runner_input = RunnerInput(
+        builder_result.artifact_path,
+        device_type="cuda",
+        args_info=[TensorInfo("float32", data_shape)],
+    )
+    runner = LocalRunner(
+        evaluator_config=EvaluatorConfig(
             number=5,
             repeat=2,
             min_repeat_ms=0,
             enable_cpu_cache_flush=False,
-        )
-
-        runner_input = RunnerInput(
-            builder_result.artifact_path, "cuda", [TensorInfo("float32", data_shape)]
-        )
-
-        def eval_func(rt_mod, device, evaluator_config, repeated_args):
-            rt_mod = 
tvm.contrib.graph_executor.GraphModule(rt_mod["default"](device))
-
-            eval = rt_mod.module.time_evaluator(
-                func_name="run",
-                dev=device,
-                number=evaluator_config.number,
-                repeat=evaluator_config.repeat,
-                min_repeat_ms=evaluator_config.min_repeat_ms,
-                f_preproc="cache_flush_cpu_non_first_arg"
-                if evaluator_config.enable_cpu_cache_flush
-                else "",
-            )
-            repeated_costs: List[List[float]] = []
-            for args in repeated_args:
-                profile_result = eval(*args)
-                repeated_costs.append(profile_result.results)
-
-            costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)]
-            return costs
-
-        runner = LocalRunner(
-            evaluator_config=evaluator_config,
-            f_run_evaluator=eval_func,
-        )
-
-        # Run the module
-        (runner_future,) = runner.run([runner_input])
-        runner_result = runner_future.result()
-        assert runner_result is not None
-        assert runner_result.run_secs is not None
-        assert runner_result.error_msg is None
-
-        for result in runner_result.run_secs:
-            if isinstance(result, FloatImm):
-                result = result.value
-            assert isinstance(result, float)
-            assert result >= 0.0
-
-    else:
-        # Without meta_schedule
-        if use_trt:
-            mod, config = tensorrt.partition_for_tensorrt(mod)
-            with tvm.transform.PassContext(
-                opt_level=3, config={"relay.ext.tensorrt.options": config}
-            ):
-                func = relay.create_executor(
-                    mode, mod=mod, device=tvm.cuda(0), target="cuda"
-                ).evaluate()
-        else:
-            with tvm.transform.PassContext(opt_level=3):
-                func = relay.create_executor(
-                    mode, mod=mod, device=tvm.cuda(0), target="cuda", params=params
-                ).evaluate()
-
-
-@tvm.testing.requires_cuda
+        ),
+        f_run_evaluator=run_with_graph_executor,
+    )
+
+    # Run the module
+    runner_future = runner.run([runner_input])[0]
+    runner_result = runner_future.result()
+    assert runner_result is not None
+    assert runner_result.error_msg is None, runner_result.error_msg
+    assert runner_result.run_secs is not None
+
+    for result in runner_result.run_secs:
+        if isinstance(result, FloatImm):
+            result = result.value
+        assert isinstance(result, float)
+        assert result >= 0.0
+
+
 @has_tensorrt_codegen
-@has_tensorrt_runtime
 def test_conv2d_relu():
     data_shape = (1, 1280, 14, 14)
     out_channels = 256
@@ -206,21 +153,17 @@ def test_conv2d_relu():
     verify_meta_schedule_with_tensorrt(mod, params, data_shape)
 
 
-@tvm.testing.requires_cuda
 @has_tensorrt_codegen
-@has_tensorrt_runtime
-@pytest.mark.parametrize(
-    "model_name",
-    ["resnet-50", "mobilenet"],
-)
-@pytest.mark.parametrize("batch_size", [1])
-@pytest.mark.parametrize("use_meta_sched", [True])
+@pytest.mark.parametrize("model_name", ["resnet_50"])
+@pytest.mark.parametrize("input_shape", [[1, 3, 224, 224]])
 @pytest.mark.parametrize("use_trt", [True, False])
-def test_relay_model(model_name: str, batch_size: int, use_meta_sched: bool, use_trt: bool):
-
-    mod, params, input_shape, output_shape = get_network(name=model_name, batch_size=batch_size)
+def test_relay_model(model_name: str, input_shape: List[int], use_trt: bool):
+    mod, params, _ = get_network(model_name, input_shape)
     verify_meta_schedule_with_tensorrt(
-        mod, params, input_shape, use_meta_sched=use_meta_sched, use_trt=use_trt, mode="vm"
+        mod,
+        params,
+        input_shape,
+        use_trt,
     )
 
 
diff --git a/tests/python/unittest/test_meta_schedule_integration.py b/tests/python/unittest/test_meta_schedule_integration.py
index 3676e3a..50dc928 100644
--- a/tests/python/unittest/test_meta_schedule_integration.py
+++ b/tests/python/unittest/test_meta_schedule_integration.py
@@ -18,30 +18,29 @@ import sys
 from typing import List
 
 import pytest
-
 import tvm
 from tvm import meta_schedule as ms
 from tvm.ir.module import IRModule
-from tvm.meta_schedule.utils import derived_object
-from tvm.tir import Schedule
-from tvm.target import Target
-from tvm.meta_schedule.database import PyDatabase, Workload, TuningRecord
+from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload
 from tvm.meta_schedule.integration import (
+    ApplyHistoryBest,
     ExtractedTask,
     MetaScheduleContext,
     TaskExtraction,
-    ApplyHistoryBest,
 )
-from tvm.meta_schedule.testing import get_network
+from tvm.meta_schedule.testing.relay_workload import get_network
+from tvm.meta_schedule.utils import derived_object
 from tvm.script import tir as T
+from tvm.target import Target
+from tvm.tir import Schedule
 
-# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring,unbalanced-tuple-unpacking
+# pylint: disable=no-member,line-too-long,too-many-nested-blocks,unbalanced-tuple-unpacking,no-self-argument,missing-docstring,invalid-name
 
 
 @tvm.script.ir_module
 class MockModule:
     @T.prim_func
-    def main(a: T.handle, b: T.handle) -> None:  # pylint: disable=no-self-argument
+    def main(a: T.handle, b: T.handle) -> None:  # type: ignore
         T.func_attr({"global_symbol": "main", "tir.noalias": True})
         A = T.match_buffer(a, (16,), "float32")
         B = T.match_buffer(b, (16,), "float32")
@@ -51,7 +50,17 @@ class MockModule:
                 B[vi] = A[vi]
 
 
-# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring,unbalanced-tuple-unpacking
+# pylint: enable=no-member,line-too-long,too-many-nested-blocks,unbalanced-tuple-unpacking,no-self-argument
+
+
+def _has_torch():
+    import importlib.util  # pylint: disable=unused-import,import-outside-toplevel
+
+    spec = importlib.util.find_spec("torch")
+    return spec is not None
+
+
+requires_torch = pytest.mark.skipif(not _has_torch(), reason="torch is not installed")
 
 
 def _check_mock_task(tasks: List[ExtractedTask], mod: IRModule):
@@ -63,13 +72,9 @@ def _check_mock_task(tasks: List[ExtractedTask], mod: IRModule):
     tvm.ir.assert_structural_equal(tir_mod, MockModule)
 
 
+@requires_torch
 def test_meta_schedule_integration_task_extraction_query():
-    mod, _, _, _ = get_network(
-        name="resnet-18",
-        batch_size=1,
-        layout="NHWC",
-        dtype="float32",
-    )
+    mod, _, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
     env = TaskExtraction()
     env.query(task_name="mock-task", mod=mod, target=Target("llvm"), dispatched=[MockModule])
     _check_mock_task(env.tasks, mod)
@@ -93,13 +98,9 @@ def test_meta_schedule_integration_multiple_current():
                 ...
 
 
+@requires_torch
 def test_meta_schedule_integration_query_inside_with_scope():
-    mod, _, _, _ = get_network(
-        name="resnet-18",
-        batch_size=1,
-        layout="NHWC",
-        dtype="float32",
-    )
+    mod, _, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
     env = TaskExtraction()
     with env:
         MetaScheduleContext.query_inside_with_scope(
@@ -111,17 +112,43 @@ def test_meta_schedule_integration_query_inside_with_scope():
     _check_mock_task(env.tasks, mod)
 
 
+@requires_torch
 def test_meta_schedule_integration_extract_from_resnet():
-    mod, params, _, _ = get_network(
-        name="resnet-18",
-        batch_size=1,
-        layout="NHWC",
-        dtype="float32",
-    )
+    mod, params, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
     extracted_tasks = ms.integration.extract_task_from_relay(mod, target="llvm", params=params)
-    assert len(extracted_tasks) == 30
-
-
+    expected_task_names = [
+        "vm_mod_fused_" + s
+        for s in [
+            "nn_max_pool2d",
+            "nn_adaptive_avg_pool2d",
+            "nn_dense_add",
+            "nn_conv2d_add",
+            "nn_conv2d_add_1",
+            "nn_conv2d_add_2",
+            "nn_conv2d_add_add_nn_relu",
+            "nn_conv2d_add_add_nn_relu_1",
+            "nn_conv2d_add_nn_relu",
+            "nn_conv2d_add_nn_relu_1",
+            "nn_conv2d_add_nn_relu_2",
+            "nn_conv2d_add_nn_relu_3",
+            "nn_conv2d_add_nn_relu_4",
+            "nn_conv2d_add_nn_relu_5",
+            "nn_contrib_conv2d_winograd_without_weight_transform_add_add_nn_relu",
+            "nn_contrib_conv2d_winograd_without_weight_transform_add_add_nn_relu_1",
+            "nn_contrib_conv2d_winograd_without_weight_transform_add_nn_relu",
+            "nn_contrib_conv2d_winograd_without_weight_transform_add_nn_relu_1",
+            # The two tasks below are purely spatial and are ruled out by AutoScheduler
+            "layout_transform",
+            "layout_transform_reshape_squeeze",
+        ]
+    ]
+
+    assert len(extracted_tasks) == 20
+    for t in extracted_tasks:
+        assert t.task_name in expected_task_names, t.task_name
+
+
+@requires_torch
 def test_meta_schedule_integration_apply_history_best():
     @derived_object
     class DummyDatabase(PyDatabase):
@@ -161,12 +188,7 @@ def test_meta_schedule_integration_apply_history_best():
         def print_results(self) -> None:
             print("\n".join([str(r) for r in self.records]))
 
-    mod, _, _, _ = get_network(
-        name="resnet-18",
-        batch_size=1,
-        layout="NHWC",
-        dtype="float32",
-    )
+    mod, _, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
     database = DummyDatabase()
     env = ApplyHistoryBest(database)
     target = Target("llvm")
@@ -175,6 +197,7 @@ def test_meta_schedule_integration_apply_history_best():
         TuningRecord(Schedule(MockModule).trace, [1.0], workload, target, [])
     )
     mod = env.query(task_name="mock-task", mod=mod, target=target, dispatched=[MockModule])
+    mod = IRModule({"main": mod})
     assert tvm.ir.structural_equal(mod, workload.mod)
 
 
diff --git a/tests/python/unittest/test_meta_schedule_runner.py b/tests/python/unittest/test_meta_schedule_runner.py
index c2fb92c..09e708f 100644
--- a/tests/python/unittest/test_meta_schedule_runner.py
+++ b/tests/python/unittest/test_meta_schedule_runner.py
@@ -23,8 +23,8 @@ from typing import Any, List
 
 import numpy as np
 import pytest
-
 import tvm
+import tvm.testing
 from tvm._ffi import register_func
 from tvm.meta_schedule.arg_info import TensorInfo
 from tvm.meta_schedule.builder import BuilderInput, LocalBuilder
@@ -37,21 +37,25 @@ from tvm.meta_schedule.runner import (
     RunnerFuture,
     RunnerInput,
 )
+from tvm.meta_schedule.runner.local_runner import (
+    default_alloc_argument as local_default_alloc_argument,
+)
 from tvm.meta_schedule.runner.rpc_runner import (
-    default_alloc_argument as rpc_default_alloc_argument,
     T_ARG_INFO_JSON_OBJ_LIST,
     T_ARGUMENT_LIST,
 )
-from tvm.meta_schedule.runner.local_runner import (
-    default_alloc_argument as local_default_alloc_argument,
+from tvm.meta_schedule.runner.rpc_runner import (
+    default_alloc_argument as rpc_default_alloc_argument,
+)
+from tvm.meta_schedule.testing.local_rpc import LocalRPC
+from tvm.meta_schedule.utils import (
+    derived_object,
+    get_global_func_with_default_on_worker,
 )
-from tvm.meta_schedule.testing import LocalRPC
-from tvm.meta_schedule.utils import derived_object, get_global_func_with_default_on_worker
 from tvm.rpc import RPCSession
 from tvm.runtime import Device, Module
 from tvm.script import tir as T
 from tvm.target import Target
-import tvm.testing
 from tvm.tir import FloatImm
 
 MATMUL_N = 16
@@ -886,4 +890,4 @@ def test_meta_schedule_local_runner_add_test():
 
 
 if __name__ == "__main__":
-    test_meta_schedule_local_single_run()
+    sys.exit(pytest.main([__file__] + sys.argv[1:]))
diff --git a/tests/python/unittest/test_meta_schedule_tune_relay.py b/tests/python/unittest/test_meta_schedule_tune_relay.py
index 1443110..6bf59d2 100644
--- a/tests/python/unittest/test_meta_schedule_tune_relay.py
+++ b/tests/python/unittest/test_meta_schedule_tune_relay.py
@@ -17,27 +17,20 @@
 # pylint: disable=missing-docstring
 import logging
 import tempfile
-import pytest
-import numpy as np
-from typing import Tuple, List
-
-from tvm.meta_schedule.utils import derived_object
-
-try:
-    import torch
-except ModuleNotFoundError:
-    pass
+from typing import List
 
+import numpy as np
+import pytest
 import tvm
 from tvm import relay
-from tvm.ir import IRModule
-from tvm.runtime.ndarray import cpu, cuda
-from tvm.target.target import Target
 from tvm.contrib import graph_executor
+from tvm.ir import IRModule
 from tvm.meta_schedule import ReplayTraceConfig
-from tvm.meta_schedule.database import PyDatabase, Workload, TuningRecord
-from tvm.meta_schedule.testing import MODEL_TYPE, MODEL_TYPES, get_torch_model
+from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload
+from tvm.meta_schedule.testing.relay_workload import get_network
 from tvm.meta_schedule.tune import tune_relay
+from tvm.meta_schedule.utils import derived_object
+from tvm.target.target import Target
 
 logging.basicConfig()
 logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG)
@@ -83,47 +76,33 @@ class DummyDatabase(PyDatabase):
 
 
 @pytest.mark.skip("Integration test")
-@pytest.mark.parametrize("model_name", ["resnet18", "mobilenet_v2", "bert_base"])
-@pytest.mark.parametrize("batch_size", [1])
-@pytest.mark.parametrize("target", ["llvm --num-cores=16", "nvidia/geforce-rtx-3070"])
-def test_meta_schedule_tune_relay(model_name: str, batch_size: int, target: str):
-    if model_name == "inception_v3" and batch_size == 1:
-        pytest.skip("inception_v3 does not handle batch_size of 1")
-
-    input_shape: Tuple[int, ...]
-    input_name = "input0"
-    dev = tvm.cpu() if str(target).startswith("llvm") else cuda()
-    if MODEL_TYPES[model_name] == MODEL_TYPE.TEXT_CLASSIFICATION:
-        seq_length = 128
-        input_name = "input_ids"
-        input_shape = (batch_size, seq_length)
+@pytest.mark.parametrize(
+    "model_name, input_shape, target",
+    [
+        ("resnet_18", [1, 3, 224, 224], "llvm --num-cores=16"),
+        ("resnet_18", [1, 3, 224, 224], "nvidia/geforce-rtx-3070"),
+        ("mobilenet_v2", [1, 3, 224, 224], "llvm --num-cores=16"),
+        ("mobilenet_v2", [1, 3, 224, 224], "nvidia/geforce-rtx-3070"),
+        ("bert_base", [1, 64], "llvm --num-cores=16"),
+        ("bert_base", [1, 64], "nvidia/geforce-rtx-3070"),
+    ],
+)
+def test_meta_schedule_tune_relay(
+    model_name: str,
+    input_shape: List[int],
+    target: str,
+):
+    dev = tvm.cpu() if str(target).startswith("llvm") else tvm.cuda()
+    if model_name.startswith("bert"):
         data = tvm.nd.array(np.random.randint(0, 30521, size=input_shape), dev)  # embedding size
     else:
-        if MODEL_TYPES[model_name] == MODEL_TYPE.IMAGE_CLASSIFICATION:
-            input_shape = (batch_size, 3, 299, 299)
-        elif MODEL_TYPES[model_name] == MODEL_TYPE.SEGMENTATION:
-            input_shape = (batch_size, 3, 299, 299)
-        elif MODEL_TYPES[model_name] == MODEL_TYPE.OBJECT_DETECTION:
-            input_shape = (1, 3, 300, 300)
-        elif MODEL_TYPES[model_name] == MODEL_TYPE.VIDEO_CLASSIFICATION:
-            input_shape = (batch_size, 3, 3, 299, 299)
-        else:
-            raise ValueError("Unsupported model: " + model_name)
         data = tvm.nd.array(np.random.randn(*input_shape).astype("float32"), dev)
 
-    output_shape: Tuple[int, int] = (batch_size, 1000)
-
-    mod, params = get_torch_model(
-        model_name=model_name,
-        input_shape=input_shape,
-        output_shape=output_shape,
-        dtype="float32",
-    )
-
+    mod, params, (input_name, _, _) = get_network(name=model_name, input_shape=input_shape)
+    target = Target(target)
     with tempfile.TemporaryDirectory() as work_dir:
-        target = Target(target)
         database = DummyDatabase()
-        rt_mod: tvm.module = tune_relay(
+        rt_mod: tvm.runtime.Module = tune_relay(
             mod=mod,
             params=params,
             target=target,
@@ -136,7 +115,7 @@ def test_meta_schedule_tune_relay(model_name: str, batch_size: int, target: str)
         )
         # Compile without meta-scheduler for correctness check
         with tvm.transform.PassContext(opt_level=0):
-            rt_mod2 = relay.build(mod, target=target, params=params)
+            rt_mod2 = relay.build(mod, target=Target("llvm"), params=params)
 
         def get_output(data, lib):
             module = graph_executor.GraphModule(lib["default"](dev))
@@ -146,14 +125,14 @@ def test_meta_schedule_tune_relay(model_name: str, batch_size: int, target: str)
 
         # Check correctness
         actual_output = get_output(data, rt_mod)
-        expected_output = get_output(data, rt_mod2)
+        expected_output = get_output(tvm.nd.array(data.numpy(), device=tvm.cpu()), rt_mod2)
         assert np.allclose(actual_output, expected_output, rtol=1e-4, atol=2e-4)
 
 
 if __name__ == """__main__""":
-    test_meta_schedule_tune_relay("resnet18", 1, "llvm --num-cores=16")
-    test_meta_schedule_tune_relay("resnet18", 1, "nvidia/geforce-rtx-3070")
-    test_meta_schedule_tune_relay("mobilenet_v2", 1, "llvm --num-cores=16")
-    test_meta_schedule_tune_relay("mobilenet_v2", 1, "nvidia/geforce-rtx-3070")
-    test_meta_schedule_tune_relay("bert_base", 1, "llvm --num-cores=16")
-    test_meta_schedule_tune_relay("bert_base", 1, "nvidia/geforce-rtx-3070")
+    test_meta_schedule_tune_relay("resnet_18", [1, 3, 224, 224], "llvm --num-cores=16")
+    test_meta_schedule_tune_relay("resnet_18", [1, 3, 224, 224], "nvidia/geforce-rtx-3070")
+    test_meta_schedule_tune_relay("mobilenet_v2", [1, 3, 224, 224], "llvm --num-cores=16")
+    test_meta_schedule_tune_relay("mobilenet_v2", [1, 3, 224, 224], "nvidia/geforce-rtx-3070")
+    test_meta_schedule_tune_relay("bert_base", [1, 64], "llvm --num-cores=16")
+    test_meta_schedule_tune_relay("bert_base", [1, 64], "nvidia/geforce-rtx-3070")

Reply via email to