This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 085d36c [MetaSchedule] Refactor testing workloads (#10497)
085d36c is described below
commit 085d36ca7f86fec31af78a2d7ec924266bf7b147
Author: Junru Shao <[email protected]>
AuthorDate: Sat Mar 5 17:41:50 2022 -0800
[MetaSchedule] Refactor testing workloads (#10497)
---
python/tvm/auto_scheduler/relay_integration.py | 12 +-
python/tvm/meta_schedule/integration.py | 15 +-
python/tvm/meta_schedule/testing/__init__.py | 3 -
python/tvm/meta_schedule/testing/byoc_trt.py | 53 ---
.../meta_schedule/testing/conv2d_winograd_cpu.py | 2 +-
.../meta_schedule/testing/custom_builder_runner.py | 140 +++++++
python/tvm/meta_schedule/testing/relay_workload.py | 435 +++++++++++++++------
python/tvm/relay/build_module.py | 31 +-
python/tvm/topi/nn/conv2d.py | 8 +-
src/meta_schedule/integration.cc | 22 +-
src/meta_schedule/schedule_rule/winograd.cc | 2 +-
.../unittest/test_meta_schedule_byoc_tensorrt.py | 187 +++------
.../unittest/test_meta_schedule_integration.py | 97 +++--
tests/python/unittest/test_meta_schedule_runner.py | 20 +-
.../unittest/test_meta_schedule_tune_relay.py | 95 ++---
15 files changed, 683 insertions(+), 439 deletions(-)
diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py
index e9bb68a..7ff1840 100644
--- a/python/tvm/auto_scheduler/relay_integration.py
+++ b/python/tvm/auto_scheduler/relay_integration.py
@@ -329,9 +329,9 @@ def auto_schedule_topi(func_name, outs):
"""
# pylint: disable=import-outside-toplevel
- from tvm.auto_scheduler.measure import (
+    from tvm.auto_scheduler.measure import (  # lazily import to avoid recursive dependency
prepare_input_map,
- ) # lazily import to avoid recursive dependency
+ )
io_tensors, has_layout_free, has_complex_op = traverse_to_get_io_tensors(outs)
if not io_tensors:  # The compute includes dynamic shapes which are not supported yet.
@@ -482,4 +482,10 @@ def is_auto_scheduler_enabled():
enabled: bool
Whether the auto-scheduler is enabled
"""
-    return PassContext.current().config.get("relay.backend.use_auto_scheduler", False)
+ return PassContext.current().config.get(
+ "relay.backend.use_auto_scheduler",
+ False,
+ ) or PassContext.current().config.get(
+ "relay.backend.use_meta_schedule",
+ False,
+ )
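
With this change, is_auto_scheduler_enabled() reports True when either the auto-scheduler or the meta-schedule backend flag is set in the active PassContext. A minimal sketch of the intended behavior (hypothetical usage, not part of this commit):

    import tvm
    from tvm.auto_scheduler.relay_integration import is_auto_scheduler_enabled

    # Outside any PassContext, neither flag is set, so the path is disabled.
    assert not is_auto_scheduler_enabled()

    # Enabling meta-schedule now also turns the shared code path on.
    with tvm.transform.PassContext(config={"relay.backend.use_meta_schedule": True}):
        assert is_auto_scheduler_enabled()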
diff --git a/python/tvm/meta_schedule/integration.py b/python/tvm/meta_schedule/integration.py
index 178239e..26b0144 100644
--- a/python/tvm/meta_schedule/integration.py
+++ b/python/tvm/meta_schedule/integration.py
@@ -204,10 +204,8 @@ def extract_task_from_relay(
params: Optional[Dict[str, NDArray]] = None,
*,
opt_level: int = 3,
- pass_config: Dict[str, Any] = {
- "relay.backend.use_meta_schedule": True,
- },
- disabled_pass: List[str] = [],
+ pass_config: Optional[Dict[str, Any]] = None,
+ disabled_pass: Optional[List[str]] = None,
) -> List[ExtractedTask]:
"""Extract tuning tasks from a relay program.
@@ -221,9 +219,9 @@ def extract_task_from_relay(
The associated parameters of the program
opt_level : int
The optimization level of the compiler
- pass_config : Dict[str, Any]
+ pass_config : Optional[Dict[str, Any]]
The pass config of the compiler
- disabled_pass : List[str]
+ disabled_pass : Optional[List[str]]
The list of disabled passes of the compiler
Returns
@@ -250,6 +248,11 @@ def extract_task_from_relay(
thread.start()
thread.join()
+ if disabled_pass is None:
+ disabled_pass = []
+ if pass_config is None:
+ pass_config = {"relay.backend.use_meta_schedule": True}
+
env = TaskExtraction()
if isinstance(mod, RelayFunc):
mod = IRModule.from_expr(mod)
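
The signature change in this file also fixes a classic Python pitfall: default values like {} and [] are evaluated once at definition time and shared across every call, so mutating them leaks state between invocations. A standalone illustration of the pattern adopted here (sketch, not part of this commit):

    def bad(xs: list = []):  # the default list is created once and reused
        xs.append(1)
        return xs

    assert bad() == [1]
    assert bad() == [1, 1]  # state from the first call leaks into the second

    def good(xs: list = None):  # default to None, construct inside the body
        if xs is None:
            xs = []
        xs.append(1)
        return xs

    assert good() == [1]
    assert good() == [1]  # every call gets a fresh list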
diff --git a/python/tvm/meta_schedule/testing/__init__.py b/python/tvm/meta_schedule/testing/__init__.py
index 85b48b3..5d6081f 100644
--- a/python/tvm/meta_schedule/testing/__init__.py
+++ b/python/tvm/meta_schedule/testing/__init__.py
@@ -15,6 +15,3 @@
# specific language governing permissions and limitations
# under the License.
"""Testing utilities in meta schedule"""
-from .byoc_trt import relay_build_with_tensorrt
-from .local_rpc import LocalRPC
-from .relay_workload import MODEL_TYPE, MODEL_TYPES, get_network, get_torch_model
diff --git a/python/tvm/meta_schedule/testing/byoc_trt.py b/python/tvm/meta_schedule/testing/byoc_trt.py
deleted file mode 100644
index d459518..0000000
--- a/python/tvm/meta_schedule/testing/byoc_trt.py
+++ /dev/null
@@ -1,53 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""TensorRT-MetaSchedule integration"""
-# pylint: disable=import-outside-toplevel
-
-from typing import List
-import tvm
-from tvm.runtime import Module
-from tvm.meta_schedule.builder import BuilderResult
-from tvm.target import Target
-
-
-def relay_build_with_tensorrt(
- mod: Module,
- target: Target,
- params: dict,
-) -> List[BuilderResult]:
- """Build a Relay IRModule with TensorRT BYOC
- Parameters
- ----------
- mod : IRModule
- The Relay IRModule to build.
- target : Target
- The target to build the module for.
- params : Dict[str, NDArray]
- The parameter dict to build the module with.
- Returns
- -------
- mod : runtime.Module
- The built module.
- """
- from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt
-
- assert isinstance(target, Target)
- mod, config = partition_for_tensorrt(mod, params)
-    with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
-        result = tvm.relay.build_module._build_module_no_factory(mod, "cuda", "llvm", params)
- assert isinstance(result, Module)
- return result
diff --git a/python/tvm/meta_schedule/testing/conv2d_winograd_cpu.py b/python/tvm/meta_schedule/testing/conv2d_winograd_cpu.py
index bfd5f45..261768c 100644
--- a/python/tvm/meta_schedule/testing/conv2d_winograd_cpu.py
+++ b/python/tvm/meta_schedule/testing/conv2d_winograd_cpu.py
@@ -79,7 +79,7 @@ def conv2d_winograd_cpu(
eps_1, nu_1, p_1, ci_1, r_a, r_b = T.axis.remap(
"SSSSRR", [i0_4, i1_4, i2_3, i3_3, i4, i5]
)
-            T.block_attr({"schedule_rule": "meta_schedule.winograd_data_pack.cpu"})
+            T.block_attr({"schedule_rule": "meta_schedule.winograd_data_pack.llvm"})
T.reads(
[
data_pack[eps_1, nu_1, p_1, ci_1],
diff --git a/python/tvm/meta_schedule/testing/custom_builder_runner.py b/python/tvm/meta_schedule/testing/custom_builder_runner.py
new file mode 100644
index 0000000..87bad5a
--- /dev/null
+++ b/python/tvm/meta_schedule/testing/custom_builder_runner.py
@@ -0,0 +1,140 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Customized builder and runner methods"""
+# pylint: disable=import-outside-toplevel
+
+from typing import TYPE_CHECKING, Dict, List
+
+if TYPE_CHECKING:
+ from tvm.ir import IRModule
+ from tvm.meta_schedule.runner import EvaluatorConfig
+ from tvm.runtime import Device, Module, NDArray
+ from tvm.target import Target
+
+
+def build_relay(
+ mod: "IRModule",
+ target: "Target",
+ params: Dict[str, "NDArray"],
+) -> "Module":
+ """Build a Relay IRModule
+
+ Parameters
+ ----------
+ mod : IRModule
+ The Relay IRModule to build.
+ target : Target
+ The target to build the module for.
+ params : Dict[str, NDArray]
+ The parameter dict to build the module with.
+
+ Returns
+ -------
+ mod : runtime.Module
+ The built module.
+ """
+ from tvm.relay.build_module import _build_module_no_factory as relay_build
+ from tvm.runtime import Module
+
+ result = relay_build(mod, target=target, target_host=None, params=params)
+ assert isinstance(result, Module)
+ return result
+
+
+def build_relay_with_tensorrt(
+ mod: "IRModule",
+ target: "Target",
+ params: Dict[str, "NDArray"],
+) -> "Module":
+ """Build a Relay IRModule with TensorRT BYOC
+
+ Parameters
+ ----------
+ mod : IRModule
+ The Relay IRModule to build.
+
+ target : Target
+ The target to build the module for.
+
+ params : Dict[str, NDArray]
+ The parameter dict to build the module with.
+
+ Returns
+ -------
+ mod : runtime.Module
+ The built module.
+ """
+ from tvm.ir.transform import PassContext
+ from tvm.relay.build_module import _build_module_no_factory as relay_build
+ from tvm.relay.op.contrib import tensorrt
+ from tvm.runtime import Module
+
+ mod, config = tensorrt.partition_for_tensorrt(mod, params)
+ with PassContext(
+ opt_level=3,
+ config={"relay.ext.tensorrt.options": config},
+ ):
+        result = relay_build(mod, target=target, target_host=None, params=params)
+ assert isinstance(result, Module)
+ return result
+
+
+def run_with_graph_executor(
+ rt_mod: "Module",
+ device: "Device",
+ evaluator_config: "EvaluatorConfig",
+ repeated_args: List["NDArray"],
+) -> List[float]:
+ """Run a Relay module with GraphExecutor
+
+ Parameters
+ ----------
+ rt_mod : Module
+ The Relay module to run.
+ device : Device
+ The device to run the module on.
+ evaluator_config : EvaluatorConfig
+ The evaluator configuration to run the module with.
+ repeated_args : List[NDArray]
+ The list of repeated arguments to run the module with.
+
+ Returns
+ -------
+ results : List[float]
+ The list of results.
+ """
+ import itertools
+
+ from tvm.contrib.graph_executor import GraphModule
+
+ graph_mod = GraphModule(rt_mod["default"](device))
+ evaluator = graph_mod.module.time_evaluator(
+ func_name="run",
+ dev=device,
+ number=evaluator_config.number,
+ repeat=evaluator_config.repeat,
+ min_repeat_ms=evaluator_config.min_repeat_ms,
+ f_preproc="cache_flush_cpu_non_first_arg"
+ if evaluator_config.enable_cpu_cache_flush
+ else "",
+ )
+ repeated_costs = []
+ for args in repeated_args:
+ profile_result = evaluator(*args)
+ repeated_costs.append(profile_result.results)
+    costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)]
+ return costs
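
These helpers are designed to be plugged into the meta-schedule builder and runner as custom callbacks. A sketch of the wiring, mirroring the updated TensorRT test later in this commit (mod and params are assumed to be an existing Relay IRModule and its parameter dict):

    from tvm.meta_schedule.builder import BuilderInput, LocalBuilder
    from tvm.meta_schedule.runner import EvaluatorConfig, LocalRunner
    from tvm.meta_schedule.testing.custom_builder_runner import (
        build_relay,
        run_with_graph_executor,
    )
    from tvm.target import Target

    # Build with the plain Relay pipeline; swap in build_relay_with_tensorrt for BYOC.
    builder = LocalBuilder(f_build=build_relay)
    (builder_result,) = builder.build([BuilderInput(mod, Target("cuda"), params)])

    # Time the built artifact through the graph executor.
    runner = LocalRunner(
        evaluator_config=EvaluatorConfig(
            number=5,
            repeat=2,
            min_repeat_ms=0,
            enable_cpu_cache_flush=False,
        ),
        f_run_evaluator=run_with_graph_executor,
    )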
diff --git a/python/tvm/meta_schedule/testing/relay_workload.py b/python/tvm/meta_schedule/testing/relay_workload.py
index 2f1ffdd..29cc70a 100644
--- a/python/tvm/meta_schedule/testing/relay_workload.py
+++ b/python/tvm/meta_schedule/testing/relay_workload.py
@@ -15,154 +15,333 @@
# specific language governing permissions and limitations
# under the License.
"""Workloads in Relay IR"""
-from enum import Enum
-from typing import Dict, Tuple
+# pylint: disable=import-outside-toplevel
+import multiprocessing
+import os
+import pickle
+from typing import Any, Dict, List, Optional, Tuple
-import tvm.relay.testing # pylint: disable=unused-import
+import tvm
+import tvm.relay.testing
from tvm import relay
from tvm.ir import IRModule
-from tvm.runtime import NDArray
-
-# Model types supported in Torchvision
-class MODEL_TYPE(Enum): # pylint: disable=invalid-name
- IMAGE_CLASSIFICATION = (1,)
- VIDEO_CLASSIFICATION = (2,)
- SEGMENTATION = (3,)
- OBJECT_DETECTION = (4,)
- TEXT_CLASSIFICATION = (5,)
-
-
-# Specify the type of each model
-MODEL_TYPES = {
- "resnet18": MODEL_TYPE.IMAGE_CLASSIFICATION,
- "mobilenet_v2": MODEL_TYPE.IMAGE_CLASSIFICATION,
- "bert_base": MODEL_TYPE.TEXT_CLASSIFICATION,
-}
-
-
-def get_torch_model(
- model_name: str,
- input_shape: Tuple[int, ...],
- output_shape: Tuple[int, int], # pylint: disable=unused-argument
- dtype: str = "float32",
-) -> Tuple[IRModule, Dict[str, NDArray]]:
- """Load model from torch model zoo
- Parameters
- ----------
- model_name : str
- The name of the model to load
- input_shape: Tuple[int, ...]
- Tuple for input shape
- output_shape: Tuple[int, int]
- Tuple for output shape
- dtype: str
- Tensor data type
- """
+from tvm.meta_schedule.integration import ExtractedTask, extract_task_from_relay
+from tvm.runtime import NDArray, load_param_dict, save_param_dict
+from tvm.target import Target
- assert dtype == "float32"
-    import torch  # type: ignore  # pylint: disable=import-error,import-outside-toplevel
-    from torchvision import models  # type: ignore  # pylint: disable=import-error,import-outside-toplevel
-    import transformers  # type: ignore  # pylint: disable=import-error,import-outside-toplevel
-    import os  # type: ignore  # pylint: disable=import-error,import-outside-toplevel
+def _get_network(
+ args: Tuple[str, List[int]]
+) -> Tuple[IRModule, bytearray, Tuple[str, List[int], str]]:
+ name: str
+ input_shape: List[int]
+ name, input_shape = args
- def do_trace(model, inp):
- model.eval()
- model_trace = torch.jit.trace(model, inp)
- model_trace.eval()
- return model_trace
+ mod: IRModule
+
+ if name in [
+ "resnet_18",
+ "resnet_50",
+ "wide_resnet_50",
+ "resnext_50",
+ "mobilenet_v2",
+ "mobilenet_v3",
+ "inception_v3",
+ "densenet_121",
+ "resnet3d_18",
+ "vgg_16",
+ ]:
+ import torch # type: ignore
+ from torchvision import models # type: ignore
+
+ if name in ["resnet_18", "resnet_50"]:
+ model = getattr(models, name.replace("_", ""))(pretrained=False)
+ elif name == "wide_resnet_50":
+ model = getattr(models, "wide_resnet50_2")(pretrained=False)
+ elif name == "resnext_50":
+ model = getattr(models, "resnext50_32x4d")(pretrained=False)
+ elif name == "mobilenet_v2":
+ model = getattr(models, name)(pretrained=False)
+ elif name == "mobilenet_v3":
+ model = getattr(models, name + "_large")(pretrained=False)
+ elif name == "inception_v3":
+ model = getattr(models, name)(pretrained=False, aux_logits=False)
+ elif name == "densenet_121":
+ model = getattr(models, name.replace("_", ""))(pretrained=False)
+ elif name == "resnet3d_18":
+ model = models.video.r3d_18(pretrained=False)
+ elif name == "vgg_16":
+ model = getattr(models, name.replace("_", ""))(pretrained=False)
- # Load model from torchvision
- if MODEL_TYPES[model_name] == MODEL_TYPE.TEXT_CLASSIFICATION:
+ dtype = "float32"
+        input_data = torch.randn(input_shape).type(  # pylint: disable=no-member
+ {
+ "float32": torch.float32, # pylint: disable=no-member
+ }[dtype]
+ )
+ scripted_model = torch.jit.trace(model, input_data).eval()
+ input_name = "input0"
+ shape_list = [(input_name, input_shape)]
+ mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
+ with tvm.transform.PassContext(opt_level=3):
+ mod = tvm.transform.Sequential(
+ [
+ relay.transform.RemoveUnusedFunctions(),
+ relay.transform.ConvertLayout(
+ {
+ "nn.conv2d": ["NHWC", "default"],
+ "nn.conv3d": ["NDHWC", "default"],
+ "nn.max_pool2d": ["NHWC", "default"],
+ "nn.avg_pool2d": ["NHWC", "default"],
+ }
+ ),
+ ]
+ )(mod)
+ inputs = (input_name, input_shape, dtype)
+ elif name in ["bert_tiny", "bert_base", "bert_medium", "bert_large"]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"
- model = transformers.BertModel(
- transformers.BertConfig(
+ # pip3 install transformers==3.5 torch==1.7
+ import torch # type: ignore
+ import transformers # type: ignore
+
+ config_dict = {
+ "bert_tiny": transformers.BertConfig(
+ num_hidden_layers=6,
+ hidden_size=512,
+ intermediate_size=2048,
+ num_attention_heads=8,
+ return_dict=False,
+ ),
+ "bert_base": transformers.BertConfig(
num_hidden_layers=12,
hidden_size=768,
intermediate_size=3072,
num_attention_heads=12,
return_dict=False,
- )
- )
+ ),
+ "bert_medium": transformers.BertConfig(
+ num_hidden_layers=12,
+ hidden_size=1024,
+ intermediate_size=4096,
+ num_attention_heads=16,
+ return_dict=False,
+ ),
+ "bert_large": transformers.BertConfig(
+ num_hidden_layers=24,
+ hidden_size=1024,
+ intermediate_size=4096,
+ num_attention_heads=16,
+ return_dict=False,
+ ),
+ }
+ configuration = config_dict[name]
+ model = transformers.BertModel(configuration)
+ input_name = "input_ids"
+ input_dtype = "int64"
+ a = torch.randint(10000, input_shape) # pylint: disable=no-member
model.eval()
- input_data = torch.randint(10000, input_shape)
- shape_list = [("input_ids", input_shape)]
- scripted_model = torch.jit.trace(model, [input_data], strict=False)
- elif MODEL_TYPES[model_name] == MODEL_TYPE.IMAGE_CLASSIFICATION:
- model = getattr(models, model_name)()
- # Setup input
- input_data = torch.randn(input_shape).type(torch.float32)
- shape_list = [("input0", input_shape)]
- # Get trace. Depending on the model type, wrapper may be necessary.
- scripted_model = do_trace(model, input_data)
+ scripted_model = torch.jit.trace(model, [a], strict=False)
+ input_name = "input_ids"
+ shape_list = [(input_name, input_shape)]
+ mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
+ mod = relay.transform.FastMath()(mod)
+ mod = relay.transform.CombineParallelBatchMatmul()(mod)
+ inputs = (input_name, input_shape, input_dtype)
+ elif name == "dcgan":
+ output_shape = input_shape
+ batch_size = output_shape[0]
+ oshape = output_shape[1:]
+ mod, params = relay.testing.dcgan.get_workload(
+ batch_size=batch_size,
+ oshape=oshape,
+ layout="NHWC",
+ )
+ inputs = ("data", [100], "float32")
else:
- raise ValueError("Unsupported model in Torch model zoo.")
+ raise ValueError("Invalid name: " + name)
+
+ params_bytearray: bytearray = save_param_dict(params)
+ return mod, params_bytearray, inputs
+
+
+def _load_cache(cache_dir: Optional[str], filename: str) -> Optional[List[Any]]:
+ if cache_dir is None:
+ return None
+ path = os.path.join(os.path.expanduser(cache_dir), filename)
+ if not os.path.exists(path):
+ return None
+ print(f"Load from cache: {path}")
+ with open(path, "rb") as i_f:
+ return pickle.load(i_f)
- # Convert torch model to relay module
- mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
- return mod, params
+
+def _save_cache(cache_dir: Optional[str], filename: str, objects: List[Any]) -> None:
+ if cache_dir is None:
+ return
+ path = os.path.join(os.path.expanduser(cache_dir), filename)
+ with open(path, "wb") as o_f:
+ pickle.dump(objects, o_f)
def get_network(
name: str,
- batch_size: int,
- layout: str = "NHWC",
- dtype: str = "float32",
-) -> Tuple[IRModule, Dict[str, NDArray], Tuple[int, int, int, int], Tuple[int, int]]:
- """Get the symbol definition and random weight of a network"""
- # meta-schedule prefers NHWC layout
- if layout == "NHWC":
- image_shape = (224, 224, 3)
- elif layout == "NCHW":
- image_shape = (3, 224, 224)
- else:
- raise ValueError("Invalid layout: " + layout)
+ input_shape: List[int],
+ *,
+ cache_dir: Optional[str] = None,
+) -> Tuple[IRModule, Dict[str, NDArray], Tuple[str, List[int], str]]:
+ """Get the symbol definition and random weight of a network
- input_shape: Tuple[int, int, int, int] = (batch_size,) + image_shape
- output_shape: Tuple[int, int] = (batch_size, 1000)
+ Parameters
+ ----------
+ name : str
+ The name of the network.
+ input_shape : List[int]
+ The shape of the input tensor.
+ cache_dir : Optional[str], optional
+ The directory to cache the generated network.
+ If not specified, the cache will be disabled.
- if name.startswith("resnet-"):
- n_layer = int(name.split("-")[1])
- mod, params = relay.testing.resnet.get_workload(
- num_layers=n_layer,
- batch_size=batch_size,
- layout=layout,
- dtype=dtype,
- image_shape=image_shape,
- )
- elif name.startswith("resnet3d-"):
- n_layer = int(name.split("-")[1])
- mod, params = relay.testing.resnet.get_workload(
- num_layers=n_layer,
- batch_size=batch_size,
- layout=layout,
- dtype=dtype,
- image_shape=image_shape,
- )
- elif name == "mobilenet":
- mod, params = relay.testing.mobilenet.get_workload(
-        batch_size=batch_size, layout=layout, dtype=dtype, image_shape=image_shape
- )
- elif name == "squeezenet_v1.1":
- assert layout == "NCHW", "squeezenet_v1.1 only supports NCHW layout"
- mod, params = relay.testing.squeezenet.get_workload(
- version="1.1",
- batch_size=batch_size,
- dtype=dtype,
- image_shape=image_shape,
- )
- elif name == "inception_v3":
-        input_shape = (batch_size, 3, 299, 299) if layout == "NCHW" else (batch_size, 299, 299, 3)
-        mod, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype)
- elif name == "mxnet":
-        from mxnet.gluon.model_zoo.vision import get_model  # type: ignore  # pylint: disable=import-outside-toplevel
-
- assert layout == "NCHW"
- block = get_model("resnet50_v1", pretrained=True)
-        mod, params = relay.frontend.from_mxnet(block, shape={"data": input_shape}, dtype=dtype)
- net = mod["main"]
- net = relay.Function(
-            net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs
+ Returns
+ -------
+ mod : IRModule
+ The IRModule representing the network.
+ params : Dict[str, NDArray]
+ The parameters of the networks.
+ inputs : Tuple[str, List[int], str]
+ The name, shape and dtype of the input tensor.
+ """
+
+ mod: IRModule
+ params: Dict[str, NDArray]
+ inputs: Tuple[str, List[int], str]
+ params_bytearray: bytearray
+
+ filename = f'relay-{name}-{",".join(str(i) for i in input_shape)}.json'
+ cached = _load_cache(cache_dir, filename)
+ if cached is None:
+ with multiprocessing.Pool(processes=1) as pool:
+ result = pool.map(_get_network, [(name, input_shape)])
+ ((mod, params_bytearray, inputs),) = result
+ cached = [mod, params_bytearray, inputs]
+ _save_cache(cache_dir, filename, cached)
+ mod, params_bytearray, inputs = cached
+ params = load_param_dict(params_bytearray)
+ return mod, params, inputs
+
+
+def extract_from_relay(
+ mod: IRModule,
+ target: Target,
+ params: Optional[Dict[str, NDArray]],
+ name: str,
+ input_shape: List[int],
+ *,
+ cache_dir: Optional[str] = None,
+ opt_level: int = 3,
+ pass_config: Optional[Dict[str, Any]] = None,
+ disabled_pass: Optional[List[str]] = None,
+) -> List[ExtractedTask]:
+ """Extract the tasks from a network.
+
+ Parameters
+ ----------
+ mod : IRModule
+ The IRModule representing the network.
+ target : Target
+ The target that the network will be deployed to.
+ params : Optional[Dict[str, NDArray]]
+ The parameters of the networks.
+ name : str
+ The name of the network.
+ input_shape : List[int]
+ The shape of the input tensor.
+ cache_dir : Optional[str]
+ The directory to cache the generated network.
+ If not specified, the cache will be disabled.
+ opt_level : int
+ The optimization level of the compiler.
+ pass_config : Optional[Dict[str, Any]]
+ The pass config of the compiler.
+ disabled_pass : Optional[List[str]]
+ The disabled pass of the compiler.
+
+ Returns
+ -------
+ extracted_tasks : List[ExtractedTask]
+ The extracted tasks.
+ """
+    filename = f'tasks-{target.kind.name}-{name}-{",".join(str(i) for i in input_shape)}.json'
+ extracted_tasks = _load_cache(cache_dir, filename)
+ if extracted_tasks is None:
+ extracted_tasks = extract_task_from_relay(
+ mod=mod,
+ target=target,
+ params=params,
+ opt_level=opt_level,
+ pass_config=pass_config,
+ disabled_pass=disabled_pass,
)
- mod = IRModule.from_expr(net)
- return mod, params, input_shape, output_shape
+ extracted_tasks = list(extracted_tasks)
+ _save_cache(cache_dir, filename, extracted_tasks)
+ return extracted_tasks
+
+
+def _build_dataset() -> List[Tuple[str, List[int]]]:
+ network_keys = []
+ for name in [
+ "resnet_18",
+ "resnet_50",
+ "mobilenet_v2",
+ "mobilenet_v3",
+ "wide_resnet_50",
+ "resnext_50",
+ "densenet_121",
+ "vgg_16",
+ ]:
+ for batch_size in [1, 4, 8]:
+ for image_size in [224, 240, 256]:
+                network_keys.append((name, [batch_size, 3, image_size, image_size]))
+ # inception-v3
+ for name in ["inception_v3"]:
+ for batch_size in [1, 2, 4]:
+ for image_size in [299]:
+                network_keys.append((name, [batch_size, 3, image_size, image_size]))
+ # resnet3d
+ for name in ["resnet3d_18"]:
+ for batch_size in [1, 2, 4]:
+ for image_size in [112, 128, 144]:
+                network_keys.append((name, [batch_size, 3, image_size, image_size, 16]))
+ # bert
+ for name in ["bert_tiny", "bert_base", "bert_medium", "bert_large"]:
+ for batch_size in [1, 2, 4]:
+ for seq_length in [64, 128, 256]:
+ network_keys.append((name, [batch_size, seq_length]))
+ # dcgan
+ for name in ["dcgan"]:
+ for batch_size in [1, 4, 8]:
+ for image_size in [64]:
+                network_keys.append((name, [batch_size, 3, image_size, image_size]))
+
+ return network_keys
+
+
+SUPPORTED = [
+ # TorchVision
+ "resnet_18",
+ "resnet_50",
+ "mobilenet_v2",
+ "mobilenet_v3",
+ "wide_resnet_50",
+ "resnext_50",
+ "resnet3d_18",
+ "inception_v3",
+ "densenet_121",
+ "vgg_16",
+ # Transformer
+ "bert_tiny",
+ "bert_base",
+ "bert_medium",
+ "bert_large",
+ # Relay testing
+ "dcgan",
+]
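
Taken together, the refactored module loads a workload in a fresh subprocess, caches the (mod, params, inputs) triple on disk, and can then extract tuning tasks from it. A short sketch of the round trip (the cache_dir path is a hypothetical example):

    from tvm.meta_schedule.testing.relay_workload import extract_from_relay, get_network
    from tvm.target import Target

    # The first call traces the model; later calls with the same cache_dir
    # reload the pickled module, parameters, and input spec from disk.
    mod, params, (input_name, input_shape, input_dtype) = get_network(
        "resnet_18", [1, 3, 224, 224], cache_dir="~/.cache/tvm-workloads"
    )
    tasks = extract_from_relay(
        mod,
        Target("llvm"),
        params,
        name="resnet_18",
        input_shape=[1, 3, 224, 224],
        cache_dir="~/.cache/tvm-workloads",
    )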
diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py
index 7872091..97f7adc 100644
--- a/python/tvm/relay/build_module.py
+++ b/python/tvm/relay/build_module.py
@@ -19,28 +19,31 @@ Construct the necessary state for the TVM graph executor
from a Relay expression.
"""
import warnings
-import numpy as np
+import numpy as np
from tvm.ir import IRModule
-
from tvm.ir.transform import PassContext
-from tvm.tir import expr as tvm_expr
from tvm.target import Target
-from .. import nd as _nd, autotvm, register_func
+from tvm.tir import expr as tvm_expr
+
+from .. import autotvm
+from .. import nd as _nd
+from .. import register_func
+from ..contrib import graph_executor as _graph_executor
+from ..contrib import utils as contrib_utils
from ..runtime import load_module
from ..runtime.executor import aot_executor as _aot_executor
from ..target import Target
-from ..contrib import graph_executor as _graph_executor
-from ..contrib import utils as contrib_utils
from . import _build_module
-from . import ty as _ty
from . import expr as _expr
from . import function as _function
-from .transform import InferType
-from .backend.utils import mangle_module_name
-from .backend import executor_factory as _executor_factory, Executor, Runtime
+from . import ty as _ty
+from .backend import Executor, Runtime
+from .backend import executor_factory as _executor_factory
from .backend import interpreter as _interpreter
+from .backend.utils import mangle_module_name
from .backend.vm import VMExecutor
+from .transform import InferType
def build_target_by_device_type_map(target):
@@ -287,13 +290,17 @@ def _module_export(module, file_name):  # fcompile, addons, kwargs?
@register_func("tvm.relay.build")
+def _build_module_no_factory_impl(mod, target, target_host, params, mod_name):
+    target, target_host = Target.check_and_update_host_consist(target, target_host)
+ return build(mod, target, params=params, mod_name=mod_name).module
+
+
def _build_module_no_factory(mod, target=None, target_host=None, params=None, mod_name="default"):
"""A wrapper around build which discards the Python GraphFactoryRuntime.
This wrapper is suitable to be used from other programming languages as
the runtime::Module can be freely passed between language boundaries.
"""
-    target, target_host = Target.check_and_update_host_consist(target, target_host)
- return build(mod, target, params=params, mod_name=mod_name).module
+    return _build_module_no_factory_impl(mod, target, target_host, params, mod_name)
def _reconstruct_from_deprecated_options(deprecated_params_target):
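
The hunk above splits the implementation out so that the packed function registered as "tvm.relay.build" carries an explicit mod_name parameter, while the keyword-friendly Python wrapper keeps its old signature. A hedged sketch of reaching the registered function through the FFI (mod and params are assumed to exist; argument order follows the new impl):

    import tvm

    f_build = tvm.get_global_func("tvm.relay.build")
    # positional args: mod, target, target_host, params, mod_name
    rt_mod = f_build(mod, "llvm", None, params, "default")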
diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py
index b1230c0..dbff3f6 100644
--- a/python/tvm/topi/nn/conv2d.py
+++ b/python/tvm/topi/nn/conv2d.py
@@ -1052,6 +1052,12 @@ def _conv2d_winograd_nhwc_impl(
)
# transform data
+ target = tvm.target.Target.current(allow_none=True)
+ if target is not None:
+ target_kind = "meta_schedule.winograd_data_pack." + target.kind.name
+ else:
+ target_kind = "None"
+
r_a = te.reduce_axis((0, alpha), "r_a")
r_b = te.reduce_axis((0, alpha), "r_b")
data_pack = te.compute(
@@ -1062,7 +1068,7 @@ def _conv2d_winograd_nhwc_impl(
name="data_pack",
attrs={
"auto_scheduler_simplify_const_tensor_indices": ["eps", "nu",
"r_a", "r_b"],
- "schedule_rule": "meta_schedule.winograd_data_pack.cpu",
+ "schedule_rule": target_kind,
},
# the attrs are necessary hints for the auto-scheduler
)
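
The "schedule_rule" block attribute is now derived from the kind of the current target, so the same TE compute can dispatch to a per-backend schedule rule. A small sketch of the naming scheme (illustrative only):

    import tvm

    with tvm.target.Target("llvm"):
        kind = tvm.target.Target.current().kind.name  # "llvm"
        rule = "meta_schedule.winograd_data_pack." + kind
        # rule == "meta_schedule.winograd_data_pack.llvm", matching the global
        # renamed in src/meta_schedule/schedule_rule/winograd.cc below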
diff --git a/src/meta_schedule/integration.cc b/src/meta_schedule/integration.cc
index 1ecb537..1ebec19 100644
--- a/src/meta_schedule/integration.cc
+++ b/src/meta_schedule/integration.cc
@@ -28,17 +28,24 @@ namespace meta_schedule {
/**************** Utility functions ****************/
template <class FunctionType>
-bool HasOnlyOneFunction(const IRModule& mod) {
+Optional<FunctionType> GetOnlyOneFunction(const IRModule& mod) {
if (mod->functions.size() != 1) {
- return false;
+ return NullOpt;
}
for (const auto& kv : mod->functions) {
const BaseFunc& func = kv.second;
if (!func->IsInstance<typename FunctionType::ContainerType>()) {
- return false;
+ return NullOpt;
+ } else {
+ return Downcast<FunctionType>(func);
}
}
- return true;
+ return NullOpt;
+}
+
+template <class FunctionType>
+bool HasOnlyOneFunction(const IRModule& mod) {
+ return GetOnlyOneFunction<FunctionType>(mod).defined();
}
/**************** ExtractedTask ****************/
@@ -129,14 +136,17 @@ Optional<ObjectRef> ApplyHistoryBestNode::Query(runtime::String task_name, IRMod
if (database->HasWorkload(prim_mod)) {
Array<TuningRecord> records = database->GetTopK(database->CommitWorkload(prim_mod), 1);
if (records.size() == 1) {
- LOG(INFO) << "Applied history best for " << task_name << ".";
+ LOG(INFO) << "Applied history best for: " << task_name;
tir::Schedule sch = tir::Schedule::Traced(records[0]->workload->mod, /*seed=*/-1, /*debug_mask=*/0, /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone);
records[0]->trace->ApplyToSchedule(sch, false);
- return sch->mod();
+      tir::PrimFunc func = GetOnlyOneFunction<tir::PrimFunc>(sch->mod()).value();
+ LOG(INFO) << "\n" << tir::AsTVMScript(func);
+ return func;
}
}
+  LOG(WARNING) << "Cannot find workload: " << task_name << "\n" << tir::AsTVMScript(prim_mod);
return NullOpt;
}
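
Note that a successful ApplyHistoryBest query now returns the module's single tir.PrimFunc instead of the whole IRModule. A hedged sketch of what Python-side callers do with the result (mirroring the updated integration test below; env, mod, target, and MockModule are assumed to exist):

    from tvm.ir import IRModule

    func = env.query(task_name="mock-task", mod=mod, target=target, dispatched=[MockModule])
    mod = IRModule({"main": func})  # re-wrap when a full IRModule is expected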
diff --git a/src/meta_schedule/schedule_rule/winograd.cc b/src/meta_schedule/schedule_rule/winograd.cc
index 44db6f2..d8aab3a 100644
--- a/src/meta_schedule/schedule_rule/winograd.cc
+++ b/src/meta_schedule/schedule_rule/winograd.cc
@@ -69,7 +69,7 @@ TVM_REGISTER_GLOBAL("meta_schedule.winograd_inverse")
return {sch};
});
-TVM_REGISTER_GLOBAL("meta_schedule.winograd_data_pack.cpu")
+TVM_REGISTER_GLOBAL("meta_schedule.winograd_data_pack.llvm")
.set_body_typed([](Schedule sch, BlockRV data_pack) -> Array<Schedule> {
BlockRV input_tile = GetOnlyProducer(sch, data_pack);
BlockRV data_pad = GetOnlyProducer(sch, input_tile);
diff --git a/tests/python/unittest/test_meta_schedule_byoc_tensorrt.py b/tests/python/unittest/test_meta_schedule_byoc_tensorrt.py
index 3b4164c..91e2c41 100644
--- a/tests/python/unittest/test_meta_schedule_byoc_tensorrt.py
+++ b/tests/python/unittest/test_meta_schedule_byoc_tensorrt.py
@@ -15,37 +15,37 @@
# specific language governing permissions and limitations
# under the License.
""" Test Meta Schedule Builder """
+# pylint: disable=missing-docstring
+
import sys
+from typing import List
+
import pytest
-import itertools
import tvm
from tvm import relay
+from tvm.meta_schedule.arg_info import TensorInfo
+from tvm.meta_schedule.builder import BuilderInput, LocalBuilder
+from tvm.meta_schedule.runner import EvaluatorConfig, LocalRunner, RunnerInput
+from tvm.meta_schedule.testing.custom_builder_runner import (
+ build_relay,
+ build_relay_with_tensorrt,
+ run_with_graph_executor,
+)
+from tvm.meta_schedule.testing.relay_workload import get_network
from tvm.relay import testing
from tvm.relay.op.contrib import tensorrt
-import numpy as np
-from typing import List
-from tvm._ffi import register_func
from tvm.target import Target
-from tvm.runtime import Module
-from tvm.meta_schedule.arg_info import TensorInfo
-from tvm.meta_schedule.builder import BuilderInput, LocalBuilder, BuilderResult
-from tvm.meta_schedule.runner import (
- EvaluatorConfig,
- LocalRunner,
- RunnerInput,
-)
-
from tvm.tir import FloatImm
-from tvm.meta_schedule.testing import get_network
has_tensorrt_codegen = pytest.mark.skipif(
-    not tvm.get_global_func("relay.ext.tensorrt", True), reason="TensorRT codegen not available"
+ not tvm.get_global_func("relay.ext.tensorrt", True),
+ reason="TensorRT codegen not available",
)
has_tensorrt_runtime = pytest.mark.skipif(
-    not tensorrt.is_tensorrt_runtime_enabled(), reason="TensorRT runtime not available"
+ not tensorrt.is_tensorrt_runtime_enabled(),
+ reason="TensorRT runtime not available",
)
-
# conv2d+relu network
def get_conv2d_relu(
data_shape,
@@ -83,105 +83,52 @@ def get_conv2d_relu(
def verify_meta_schedule_with_tensorrt(
-    mod, params, data_shape, use_meta_sched: bool = True, use_trt: bool = True, mode: str = "vm"
+ mod,
+ params,
+ data_shape,
+ use_trt: bool = True,
):
- if use_meta_sched:
- # With meta_schedule
- dev = "cuda"
-
- # Build
- if use_trt:
- from tvm.meta_schedule.testing import relay_build_with_tensorrt
-
- builder = LocalBuilder(f_build=relay_build_with_tensorrt)
- else:
-
- def relay_build_without_tensorrt(
- mod: Module,
- target: Target,
- params: dict,
- ) -> List[BuilderResult]:
-                return tvm.relay.build_module._build_module_no_factory(mod, "cuda", "llvm", params)
-
- builder = LocalBuilder(f_build=relay_build_without_tensorrt)
-
- builder_input = BuilderInput(mod, Target(dev, host="llvm"), params)
-
- (builder_result,) = builder.build([builder_input])
- assert builder_result.error_msg is None
- assert builder_result.artifact_path is not None
-
- # Run
- evaluator_config = EvaluatorConfig(
+ # Build
+ builder = LocalBuilder(
+ f_build=build_relay_with_tensorrt if use_trt else build_relay,
+ timeout_sec=1000,
+ )
+ builder_input = BuilderInput(mod, Target("cuda"), params)
+ builder_result = builder.build([builder_input])[0]
+ assert builder_result.error_msg is None, builder_result.error_msg
+ assert builder_result.artifact_path is not None
+
+ # Run
+ runner_input = RunnerInput(
+ builder_result.artifact_path,
+ device_type="cuda",
+ args_info=[TensorInfo("float32", data_shape)],
+ )
+ runner = LocalRunner(
+ evaluator_config=EvaluatorConfig(
number=5,
repeat=2,
min_repeat_ms=0,
enable_cpu_cache_flush=False,
- )
-
- runner_input = RunnerInput(
-            builder_result.artifact_path, "cuda", [TensorInfo("float32", data_shape)]
- )
-
- def eval_func(rt_mod, device, evaluator_config, repeated_args):
-            rt_mod = tvm.contrib.graph_executor.GraphModule(rt_mod["default"](device))
-
- eval = rt_mod.module.time_evaluator(
- func_name="run",
- dev=device,
- number=evaluator_config.number,
- repeat=evaluator_config.repeat,
- min_repeat_ms=evaluator_config.min_repeat_ms,
- f_preproc="cache_flush_cpu_non_first_arg"
- if evaluator_config.enable_cpu_cache_flush
- else "",
- )
- repeated_costs: List[List[float]] = []
- for args in repeated_args:
- profile_result = eval(*args)
- repeated_costs.append(profile_result.results)
-
-            costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)]
- return costs
-
- runner = LocalRunner(
- evaluator_config=evaluator_config,
- f_run_evaluator=eval_func,
- )
-
- # Run the module
- (runner_future,) = runner.run([runner_input])
- runner_result = runner_future.result()
- assert runner_result is not None
- assert runner_result.run_secs is not None
- assert runner_result.error_msg is None
-
- for result in runner_result.run_secs:
- if isinstance(result, FloatImm):
- result = result.value
- assert isinstance(result, float)
- assert result >= 0.0
-
- else:
- # Without meta_schedule
- if use_trt:
- mod, config = tensorrt.partition_for_tensorrt(mod)
- with tvm.transform.PassContext(
- opt_level=3, config={"relay.ext.tensorrt.options": config}
- ):
- func = relay.create_executor(
- mode, mod=mod, device=tvm.cuda(0), target="cuda"
- ).evaluate()
- else:
- with tvm.transform.PassContext(opt_level=3):
- func = relay.create_executor(
-                    mode, mod=mod, device=tvm.cuda(0), target="cuda", params=params
- ).evaluate()
-
-
[email protected]_cuda
+ ),
+ f_run_evaluator=run_with_graph_executor,
+ )
+
+ # Run the module
+ runner_future = runner.run([runner_input])[0]
+ runner_result = runner_future.result()
+ assert runner_result is not None
+ assert runner_result.error_msg is None, runner_result.error_msg
+ assert runner_result.run_secs is not None
+
+ for result in runner_result.run_secs:
+ if isinstance(result, FloatImm):
+ result = result.value
+ assert isinstance(result, float)
+ assert result >= 0.0
+
+
@has_tensorrt_codegen
-@has_tensorrt_runtime
def test_conv2d_relu():
data_shape = (1, 1280, 14, 14)
out_channels = 256
@@ -206,21 +153,17 @@ def test_conv2d_relu():
verify_meta_schedule_with_tensorrt(mod, params, data_shape)
[email protected]_cuda
@has_tensorrt_codegen
-@has_tensorrt_runtime
[email protected](
- "model_name",
- ["resnet-50", "mobilenet"],
-)
[email protected]("batch_size", [1])
[email protected]("use_meta_sched", [True])
[email protected]("model_name", ["resnet_50"])
[email protected]("input_shape", [[1, 3, 224, 224]])
@pytest.mark.parametrize("use_trt", [True, False])
-def test_relay_model(model_name: str, batch_size: int, use_meta_sched: bool, use_trt: bool):
-
-    mod, params, input_shape, output_shape = get_network(name=model_name, batch_size=batch_size)
+def test_relay_model(model_name: str, input_shape: List[int], use_trt: bool):
+ mod, params, _ = get_network(model_name, input_shape)
verify_meta_schedule_with_tensorrt(
-        mod, params, input_shape, use_meta_sched=use_meta_sched, use_trt=use_trt, mode="vm"
+ mod,
+ params,
+ input_shape,
+ use_trt,
)
diff --git a/tests/python/unittest/test_meta_schedule_integration.py b/tests/python/unittest/test_meta_schedule_integration.py
index 3676e3a..50dc928 100644
--- a/tests/python/unittest/test_meta_schedule_integration.py
+++ b/tests/python/unittest/test_meta_schedule_integration.py
@@ -18,30 +18,29 @@ import sys
from typing import List
import pytest
-
import tvm
from tvm import meta_schedule as ms
from tvm.ir.module import IRModule
-from tvm.meta_schedule.utils import derived_object
-from tvm.tir import Schedule
-from tvm.target import Target
-from tvm.meta_schedule.database import PyDatabase, Workload, TuningRecord
+from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload
from tvm.meta_schedule.integration import (
+ ApplyHistoryBest,
ExtractedTask,
MetaScheduleContext,
TaskExtraction,
- ApplyHistoryBest,
)
-from tvm.meta_schedule.testing import get_network
+from tvm.meta_schedule.testing.relay_workload import get_network
+from tvm.meta_schedule.utils import derived_object
from tvm.script import tir as T
+from tvm.target import Target
+from tvm.tir import Schedule
-# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring,unbalanced-tuple-unpacking
+# pylint: disable=no-member,line-too-long,too-many-nested-blocks,unbalanced-tuple-unpacking,no-self-argument,missing-docstring,invalid-name
@tvm.script.ir_module
class MockModule:
@T.prim_func
-    def main(a: T.handle, b: T.handle) -> None:  # pylint: disable=no-self-argument
+ def main(a: T.handle, b: T.handle) -> None: # type: ignore
T.func_attr({"global_symbol": "main", "tir.noalias": True})
A = T.match_buffer(a, (16,), "float32")
B = T.match_buffer(b, (16,), "float32")
@@ -51,7 +50,17 @@ class MockModule:
B[vi] = A[vi]
-# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring,unbalanced-tuple-unpacking
+# pylint: enable=no-member,line-too-long,too-many-nested-blocks,unbalanced-tuple-unpacking,no-self-argument
+
+
+def _has_torch():
+    import importlib.util  # pylint: disable=unused-import,import-outside-toplevel
+
+ spec = importlib.util.find_spec("torch")
+ return spec is not None
+
+
+requires_torch = pytest.mark.skipif(not _has_torch(), reason="torch is not installed")
def _check_mock_task(tasks: List[ExtractedTask], mod: IRModule):
@@ -63,13 +72,9 @@ def _check_mock_task(tasks: List[ExtractedTask], mod: IRModule):
tvm.ir.assert_structural_equal(tir_mod, MockModule)
+@requires_torch
def test_meta_schedule_integration_task_extraction_query():
- mod, _, _, _ = get_network(
- name="resnet-18",
- batch_size=1,
- layout="NHWC",
- dtype="float32",
- )
+ mod, _, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
env = TaskExtraction()
env.query(task_name="mock-task", mod=mod, target=Target("llvm"), dispatched=[MockModule])
_check_mock_task(env.tasks, mod)
@@ -93,13 +98,9 @@ def test_meta_schedule_integration_multiple_current():
...
+@requires_torch
def test_meta_schedule_integration_query_inside_with_scope():
- mod, _, _, _ = get_network(
- name="resnet-18",
- batch_size=1,
- layout="NHWC",
- dtype="float32",
- )
+ mod, _, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
env = TaskExtraction()
with env:
MetaScheduleContext.query_inside_with_scope(
@@ -111,17 +112,43 @@ def test_meta_schedule_integration_query_inside_with_scope():
_check_mock_task(env.tasks, mod)
+@requires_torch
def test_meta_schedule_integration_extract_from_resnet():
- mod, params, _, _ = get_network(
- name="resnet-18",
- batch_size=1,
- layout="NHWC",
- dtype="float32",
- )
+    mod, params, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
extracted_tasks = ms.integration.extract_task_from_relay(mod, target="llvm", params=params)
- assert len(extracted_tasks) == 30
-
-
+ expected_task_names = [
+ "vm_mod_fused_" + s
+ for s in [
+ "nn_max_pool2d",
+ "nn_adaptive_avg_pool2d",
+ "nn_dense_add",
+ "nn_conv2d_add",
+ "nn_conv2d_add_1",
+ "nn_conv2d_add_2",
+ "nn_conv2d_add_add_nn_relu",
+ "nn_conv2d_add_add_nn_relu_1",
+ "nn_conv2d_add_nn_relu",
+ "nn_conv2d_add_nn_relu_1",
+ "nn_conv2d_add_nn_relu_2",
+ "nn_conv2d_add_nn_relu_3",
+ "nn_conv2d_add_nn_relu_4",
+ "nn_conv2d_add_nn_relu_5",
+            "nn_contrib_conv2d_winograd_without_weight_transform_add_add_nn_relu",
+            "nn_contrib_conv2d_winograd_without_weight_transform_add_add_nn_relu_1",
+ "nn_contrib_conv2d_winograd_without_weight_transform_add_nn_relu",
+            "nn_contrib_conv2d_winograd_without_weight_transform_add_nn_relu_1",
+            # The two tasks below are purely spatial and are ruled out by AutoScheduler
+ "layout_transform",
+ "layout_transform_reshape_squeeze",
+ ]
+ ]
+
+ assert len(extracted_tasks) == 20
+ for t in extracted_tasks:
+ assert t.task_name in expected_task_names, t.task_name
+
+
+@requires_torch
def test_meta_schedule_integration_apply_history_best():
@derived_object
class DummyDatabase(PyDatabase):
@@ -161,12 +188,7 @@ def test_meta_schedule_integration_apply_history_best():
def print_results(self) -> None:
print("\n".join([str(r) for r in self.records]))
- mod, _, _, _ = get_network(
- name="resnet-18",
- batch_size=1,
- layout="NHWC",
- dtype="float32",
- )
+ mod, _, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
database = DummyDatabase()
env = ApplyHistoryBest(database)
target = Target("llvm")
@@ -175,6 +197,7 @@ def test_meta_schedule_integration_apply_history_best():
TuningRecord(Schedule(MockModule).trace, [1.0], workload, target, [])
)
mod = env.query(task_name="mock-task", mod=mod, target=target, dispatched=[MockModule])
+ mod = IRModule({"main": mod})
assert tvm.ir.structural_equal(mod, workload.mod)
diff --git a/tests/python/unittest/test_meta_schedule_runner.py b/tests/python/unittest/test_meta_schedule_runner.py
index c2fb92c..09e708f 100644
--- a/tests/python/unittest/test_meta_schedule_runner.py
+++ b/tests/python/unittest/test_meta_schedule_runner.py
@@ -23,8 +23,8 @@ from typing import Any, List
import numpy as np
import pytest
-
import tvm
+import tvm.testing
from tvm._ffi import register_func
from tvm.meta_schedule.arg_info import TensorInfo
from tvm.meta_schedule.builder import BuilderInput, LocalBuilder
@@ -37,21 +37,25 @@ from tvm.meta_schedule.runner import (
RunnerFuture,
RunnerInput,
)
+from tvm.meta_schedule.runner.local_runner import (
+ default_alloc_argument as local_default_alloc_argument,
+)
from tvm.meta_schedule.runner.rpc_runner import (
- default_alloc_argument as rpc_default_alloc_argument,
T_ARG_INFO_JSON_OBJ_LIST,
T_ARGUMENT_LIST,
)
-from tvm.meta_schedule.runner.local_runner import (
- default_alloc_argument as local_default_alloc_argument,
+from tvm.meta_schedule.runner.rpc_runner import (
+ default_alloc_argument as rpc_default_alloc_argument,
+)
+from tvm.meta_schedule.testing.local_rpc import LocalRPC
+from tvm.meta_schedule.utils import (
+ derived_object,
+ get_global_func_with_default_on_worker,
)
-from tvm.meta_schedule.testing import LocalRPC
-from tvm.meta_schedule.utils import derived_object, get_global_func_with_default_on_worker
from tvm.rpc import RPCSession
from tvm.runtime import Device, Module
from tvm.script import tir as T
from tvm.target import Target
-import tvm.testing
from tvm.tir import FloatImm
MATMUL_N = 16
@@ -886,4 +890,4 @@ def test_meta_schedule_local_runner_add_test():
if __name__ == "__main__":
- test_meta_schedule_local_single_run()
+ sys.exit(pytest.main([__file__] + sys.argv[1:]))
diff --git a/tests/python/unittest/test_meta_schedule_tune_relay.py b/tests/python/unittest/test_meta_schedule_tune_relay.py
index 1443110..6bf59d2 100644
--- a/tests/python/unittest/test_meta_schedule_tune_relay.py
+++ b/tests/python/unittest/test_meta_schedule_tune_relay.py
@@ -17,27 +17,20 @@
# pylint: disable=missing-docstring
import logging
import tempfile
-import pytest
-import numpy as np
-from typing import Tuple, List
-
-from tvm.meta_schedule.utils import derived_object
-
-try:
- import torch
-except ModuleNotFoundError:
- pass
+from typing import List
+import numpy as np
+import pytest
import tvm
from tvm import relay
-from tvm.ir import IRModule
-from tvm.runtime.ndarray import cpu, cuda
-from tvm.target.target import Target
from tvm.contrib import graph_executor
+from tvm.ir import IRModule
from tvm.meta_schedule import ReplayTraceConfig
-from tvm.meta_schedule.database import PyDatabase, Workload, TuningRecord
-from tvm.meta_schedule.testing import MODEL_TYPE, MODEL_TYPES, get_torch_model
+from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload
+from tvm.meta_schedule.testing.relay_workload import get_network
from tvm.meta_schedule.tune import tune_relay
+from tvm.meta_schedule.utils import derived_object
+from tvm.target.target import Target
logging.basicConfig()
logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG)
@@ -83,47 +76,33 @@ class DummyDatabase(PyDatabase):
@pytest.mark.skip("Integration test")
[email protected]("model_name", ["resnet18", "mobilenet_v2", "bert_base"])
[email protected]("batch_size", [1])
[email protected]("target", ["llvm --num-cores=16", "nvidia/geforce-rtx-3070"])
-def test_meta_schedule_tune_relay(model_name: str, batch_size: int, target: str):
- if model_name == "inception_v3" and batch_size == 1:
- pytest.skip("inception_v3 does not handle batch_size of 1")
-
- input_shape: Tuple[int, ...]
- input_name = "input0"
- dev = tvm.cpu() if str(target).startswith("llvm") else cuda()
- if MODEL_TYPES[model_name] == MODEL_TYPE.TEXT_CLASSIFICATION:
- seq_length = 128
- input_name = "input_ids"
- input_shape = (batch_size, seq_length)
[email protected](
+ "model_name, input_shape, target",
+ [
+ ("resnet_18", [1, 3, 224, 224], "llvm --num-cores=16"),
+ ("resnet_18", [1, 3, 224, 224], "nvidia/geforce-rtx-3070"),
+ ("mobilenet_v2", [1, 3, 224, 224], "llvm --num-cores=16"),
+ ("mobilenet_v2", [1, 3, 224, 224], "nvidia/geforce-rtx-3070"),
+ ("bert_base", [1, 64], "llvm --num-cores=16"),
+ ("bert_base", [1, 64], "nvidia/geforce-rtx-3070"),
+ ],
+)
+def test_meta_schedule_tune_relay(
+ model_name: str,
+ input_shape: List[int],
+ target: str,
+):
+ dev = tvm.cpu() if str(target).startswith("llvm") else tvm.cuda()
+ if model_name.startswith("bert"):
data = tvm.nd.array(np.random.randint(0, 30521, size=input_shape), dev)  # embedding size
else:
- if MODEL_TYPES[model_name] == MODEL_TYPE.IMAGE_CLASSIFICATION:
- input_shape = (batch_size, 3, 299, 299)
- elif MODEL_TYPES[model_name] == MODEL_TYPE.SEGMENTATION:
- input_shape = (batch_size, 3, 299, 299)
- elif MODEL_TYPES[model_name] == MODEL_TYPE.OBJECT_DETECTION:
- input_shape = (1, 3, 300, 300)
- elif MODEL_TYPES[model_name] == MODEL_TYPE.VIDEO_CLASSIFICATION:
- input_shape = (batch_size, 3, 3, 299, 299)
- else:
- raise ValueError("Unsupported model: " + model_name)
data = tvm.nd.array(np.random.randn(*input_shape).astype("float32"), dev)
- output_shape: Tuple[int, int] = (batch_size, 1000)
-
- mod, params = get_torch_model(
- model_name=model_name,
- input_shape=input_shape,
- output_shape=output_shape,
- dtype="float32",
- )
-
+    mod, params, (input_name, _, _) = get_network(name=model_name, input_shape=input_shape)
+ target = Target(target)
with tempfile.TemporaryDirectory() as work_dir:
- target = Target(target)
database = DummyDatabase()
- rt_mod: tvm.module = tune_relay(
+ rt_mod: tvm.runtime.Module = tune_relay(
mod=mod,
params=params,
target=target,
@@ -136,7 +115,7 @@ def test_meta_schedule_tune_relay(model_name: str, batch_size: int, target: str)
)
# Compile without meta-scheduler for correctness check
with tvm.transform.PassContext(opt_level=0):
- rt_mod2 = relay.build(mod, target=target, params=params)
+ rt_mod2 = relay.build(mod, target=Target("llvm"), params=params)
def get_output(data, lib):
module = graph_executor.GraphModule(lib["default"](dev))
@@ -146,14 +125,14 @@ def test_meta_schedule_tune_relay(model_name: str, batch_size: int, target: str)
# Check correctness
actual_output = get_output(data, rt_mod)
- expected_output = get_output(data, rt_mod2)
+    expected_output = get_output(tvm.nd.array(data.numpy(), device=tvm.cpu()), rt_mod2)
assert np.allclose(actual_output, expected_output, rtol=1e-4, atol=2e-4)
if __name__ == """__main__""":
- test_meta_schedule_tune_relay("resnet18", 1, "llvm --num-cores=16")
- test_meta_schedule_tune_relay("resnet18", 1, "nvidia/geforce-rtx-3070")
- test_meta_schedule_tune_relay("mobilenet_v2", 1, "llvm --num-cores=16")
- test_meta_schedule_tune_relay("mobilenet_v2", 1, "nvidia/geforce-rtx-3070")
- test_meta_schedule_tune_relay("bert_base", 1, "llvm --num-cores=16")
- test_meta_schedule_tune_relay("bert_base", 1, "nvidia/geforce-rtx-3070")
+    test_meta_schedule_tune_relay("resnet_18", [1, 3, 224, 224], "llvm --num-cores=16")
+    test_meta_schedule_tune_relay("resnet_18", [1, 3, 224, 224], "nvidia/geforce-rtx-3070")
+    test_meta_schedule_tune_relay("mobilenet_v2", [1, 3, 224, 224], "llvm --num-cores=16")
+    test_meta_schedule_tune_relay("mobilenet_v2", [1, 3, 224, 224], "nvidia/geforce-rtx-3070")
+    test_meta_schedule_tune_relay("bert_base", [1, 64], "llvm --num-cores=16")
+    test_meta_schedule_tune_relay("bert_base", [1, 64], "nvidia/geforce-rtx-3070")