This is an automated email from the ASF dual-hosted git repository.
comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 18ce8e4  [TVMC] A simplified TVMC API for python scripting (Part 1). (#7823)
18ce8e4 is described below
commit 18ce8e4b82369fb28f832fad78dbb1c12099a66d
Author: Josh Fromm <[email protected]>
AuthorDate: Tue May 4 07:59:34 2021 -0700
[TVMC] A simplified TVMC API for python scripting (Part 1). (#7823)
* Introduce new TVMC Python API.
* Add simple testing model.
* Split result utils into stand-alone file.
---
python/tvm/auto_scheduler/search_task.py | 83 ++++--
python/tvm/auto_scheduler/task_scheduler.py | 5 +-
python/tvm/driver/tvmc/__init__.py | 3 +
python/tvm/driver/tvmc/autotuner.py | 341 ++++++++++++++++-------
python/tvm/driver/tvmc/common.py | 5 +-
python/tvm/driver/tvmc/compiler.py | 160 ++++-------
python/tvm/driver/tvmc/frontends.py | 21 +-
python/tvm/driver/tvmc/model.py | 364 +++++++++++++++++++++++++
python/tvm/driver/tvmc/result_utils.py | 60 ++++
python/tvm/driver/tvmc/runner.py | 274 ++++++++-----------
src/auto_scheduler/search_task.cc | 5 +
tests/python/driver/tvmc/conftest.py | 51 +++-
tests/python/driver/tvmc/test_autoscheduler.py | 65 ++---
tests/python/driver/tvmc/test_autotuner.py | 79 +++---
tests/python/driver/tvmc/test_command_line.py | 53 ++++
tests/python/driver/tvmc/test_compiler.py | 163 ++++++-----
tests/python/driver/tvmc/test_frontends.py | 54 ++--
tests/python/driver/tvmc/test_model.py | 65 +++++
tests/python/driver/tvmc/test_runner.py | 24 +-
tests/python/driver/tvmc/test_tvmc_common.py | 15 +-
20 files changed, 1310 insertions(+), 580 deletions(-)
diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py
index 6e73ab1..fca8894 100644
--- a/python/tvm/auto_scheduler/search_task.py
+++ b/python/tvm/auto_scheduler/search_task.py
@@ -43,40 +43,74 @@ logger = logging.getLogger("auto_scheduler")
@tvm._ffi.register_object("auto_scheduler.HardwareParams")
class HardwareParams(Object):
- """The parameters of target hardware used to guide the search policy
+ """The parameters of target hardware used to guide the search policy.
+
+ When a parameter isn't provided, it will instead use the
+ current machine's default value if target is specified.
TODO(jcf94): This is considered to be merged with the new Target
specification:
https://discuss.tvm.apache.org/t/rfc-tvm-target-specification/6844
Parameters
----------
- num_cores : int
+ num_cores : int, optional
The number of device cores.
- vector_unit_bytes : int
+ vector_unit_bytes : int, optional
The width of vector units in bytes.
- cache_line_bytes : int
+ cache_line_bytes : int, optional
The size of cache line in bytes.
- max_shared_memory_per_block : int
+ max_shared_memory_per_block : int, optional
The max shared memory per block in bytes.
- max_local_memory_per_block : int
+ max_local_memory_per_block : int, optional
The max local memory per block in bytes.
- max_threads_per_block : int
+ max_threads_per_block : int, optional
The max number of threads per block.
- max_vthread_extent : int
+ max_vthread_extent : int, optional
The max vthread extent.
- warp_size : int
+ warp_size : int, optional
The thread numbers of a warp.
+ target : str or Target, optional
+ The compilation target. Used to determine default values if provided.
+ target_host : str or Target, optional
+ The compilation target host. Used to determine default values if provided.
"""
def __init__(
self,
- num_cores,
- vector_unit_bytes,
- cache_line_bytes,
- max_shared_memory_per_block,
- max_local_memory_per_block,
- max_threads_per_block,
- max_vthread_extent,
- warp_size,
+ num_cores=None,
+ vector_unit_bytes=None,
+ cache_line_bytes=None,
+ max_shared_memory_per_block=None,
+ max_local_memory_per_block=None,
+ max_threads_per_block=None,
+ max_vthread_extent=None,
+ warp_size=None,
+ target=None,
+ target_host=None,
):
+ # If target is provided, get the default parameters for this machine.
+ if target is not None:
+ if isinstance(target, str):
+ target = tvm.target.Target(target)
+ if isinstance(target_host, str):
+ target_host = tvm.target.Target(target_host)
+ default_params = _ffi_api.GetDefaultHardwareParams(target, target_host)
+
+ if num_cores is None:
+ num_cores = default_params.num_cores
+ if vector_unit_bytes is None:
+ vector_unit_bytes = default_params.vector_unit_bytes
+ if cache_line_bytes is None:
+ cache_line_bytes = default_params.cache_line_bytes
+ if max_shared_memory_per_block is None:
+ max_shared_memory_per_block = default_params.max_shared_memory_per_block
+ if max_local_memory_per_block is None:
+ max_local_memory_per_block = default_params.max_local_memory_per_block
+ if max_threads_per_block is None:
+ max_threads_per_block = default_params.max_threads_per_block
+ if max_vthread_extent is None:
+ max_vthread_extent = default_params.max_vthread_extent
+ if warp_size is None:
+ warp_size = default_params.warp_size
+
self.__init_handle_by_constructor__(
_ffi_api.HardwareParams,
num_cores,
@@ -89,6 +123,21 @@ class HardwareParams(Object):
warp_size,
)
+ def __str__(self):
+ """Pretty printing for hardware parameter configuration."""
+ format_str = (
+ "HardwareParams:\n"
+ f" num_cores: {self.num_cores}\n"
+ f" vector_unit_bytes: {self.vector_unit_bytes}\n"
+ f" cache_line_bytes: {self.cache_line_bytes}\n"
+ f" max_shared_memory_per_block:
{self.max_shared_memory_per_block}\n"
+ f" max_local_memory_per_block:
{self.max_local_memory_per_block}\n"
+ f" max_threads_per_block: {self.max_threads_per_block}\n"
+ f" max_vthread_extent: {self.max_vthread_extent}\n"
+ f" warp_size: {self.warp_size}\n"
+ )
+ return format_str
+
@tvm._ffi.register_object("auto_scheduler.TuningOptions")
class TuningOptions(Object):
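With this change, HardwareParams can fill in unspecified fields from the
target's defaults. A minimal sketch of the new behavior (the "llvm" target
string and the num_cores override are illustrative assumptions, not part of
this diff):

.. code-block:: python

    from tvm.auto_scheduler import HardwareParams

    # Only num_cores is pinned; the remaining fields are resolved from the
    # target's defaults via GetDefaultHardwareParams.
    hw_params = HardwareParams(num_cores=8, target="llvm")
    print(hw_params)  # the new __str__ prints the resolved configuration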
diff --git a/python/tvm/auto_scheduler/task_scheduler.py b/python/tvm/auto_scheduler/task_scheduler.py
index 0221870..5cae556 100644
--- a/python/tvm/auto_scheduler/task_scheduler.py
+++ b/python/tvm/auto_scheduler/task_scheduler.py
@@ -329,7 +329,10 @@ class TaskScheduler:
tune_option.num_measures_per_round, tune_option.num_measure_trials // len(self.tasks)
)
if self.num_measures_per_round <= 0:
- raise ValueError("num_measure_trials is too small. Please set it to a higher value.")
+ raise ValueError(
+ "num_measure_trials is too small. Please set it to a higher value. "
+ f"It should be at least {len(self.tasks)} for this model."
+ )
# restore the status of the task scheduler from a log file
if self.load_log_file:
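The expanded error message reflects the constraint that num_measure_trials
must be at least the number of extracted tasks, since each task needs at
least one measurement round. A small illustrative check (the counts here are
assumptions):

.. code-block:: python

    num_tasks = 30   # e.g. tasks extracted from a model
    trials = 10000   # per-model measurement budget
    # Mirrors the scheduler's check: trials // num_tasks must be positive.
    assert trials // num_tasks > 0, f"need at least {num_tasks} trials"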
diff --git a/python/tvm/driver/tvmc/__init__.py b/python/tvm/driver/tvmc/__init__.py
index d9c1579..42184c3 100644
--- a/python/tvm/driver/tvmc/__init__.py
+++ b/python/tvm/driver/tvmc/__init__.py
@@ -22,6 +22,9 @@ TVMC - TVM driver command-line interface
from . import autotuner
from . import compiler
from . import runner
+from . import result_utils
from .frontends import load_model as load
from .compiler import compile_model as compile
from .runner import run_module as run
+from .autotuner import tune_model as tune
+from .model import TVMCModel, TVMCPackage, TVMCResult
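Together these exports form the simplified scripting flow this PR
introduces. A sketch of the end-to-end usage (the model file name is a
placeholder):

.. code-block:: python

    from tvm.driver import tvmc

    model = tvmc.load("my_model.onnx")         # TVMCModel
    records = tvmc.tune(model, target="llvm")  # path to the tuning log
    package = tvmc.compile(model, target="llvm", tuning_records=records)
    result = tvmc.run(package, device="cpu")   # TVMCResult
    print(result)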
diff --git a/python/tvm/driver/tvmc/autotuner.py b/python/tvm/driver/tvmc/autotuner.py
index bdb4c62..8f94c53 100644
--- a/python/tvm/driver/tvmc/autotuner.py
+++ b/python/tvm/driver/tvmc/autotuner.py
@@ -20,10 +20,14 @@ Provides support to auto-tuning networks using AutoTVM.
import os.path
import logging
import time
+from copy import deepcopy
+from typing import Optional, Dict, List, Union
from urllib.parse import urlparse
+import tvm
from tvm import autotvm, auto_scheduler
+from tvm.auto_scheduler.search_task import HardwareParams
from tvm.autotvm.tuner import GATuner
from tvm.autotvm.tuner import GridSearchTuner
from tvm.autotvm.tuner import RandomTuner
@@ -33,6 +37,7 @@ from tvm.target import Target
from . import common, composite_target, frontends
from .common import TVMCException
from .main import register_parser
+from .model import TVMCModel
# pylint: disable=invalid-name
@@ -52,7 +57,7 @@ def add_tune_parser(subparsers):
)
# There is some extra processing required to define the actual default value
- # for --min-repeat-ms. This is done in `drive_tune`.
+ # for --min-repeat-ms. This is done in `tune_model`.
parser.add_argument(
"--min-repeat-ms",
default=None,
@@ -93,7 +98,8 @@ def add_tune_parser(subparsers):
)
parser.add_argument(
"--rpc-key",
- help="the RPC tracker key of the target device. Required when
--rpc-tracker is provided.",
+ help="the RPC tracker key of the target device. "
+ "Required when --rpc-tracker is provided.",
)
parser.add_argument(
"--rpc-tracker",
@@ -142,50 +148,50 @@ def add_tune_parser(subparsers):
auto_scheduler_group.add_argument(
"--cache-line-bytes",
type=int,
- default=64,
- help="the size of cache line in bytes",
+ help="the size of cache line in bytes. "
+ "If not specified, it will be autoset for the current machine.",
)
auto_scheduler_group.add_argument(
"--num-cores",
type=int,
- default=4,
- help="the number of device cores",
+ help="the number of device cores. "
+ "If not specified, it will be autoset for the current machine.",
)
auto_scheduler_group.add_argument(
"--vector-unit-bytes",
type=int,
- default=16,
- help="the width of vector units in bytes",
+ help="the width of vector units in bytes. "
+ "If not specified, it will be autoset for the current machine.",
)
auto_scheduler_group.add_argument(
"--max-shared-memory-per-block",
type=int,
- default=0,
- help="the max shared memory per block in bytes",
+ help="the max shared memory per block in bytes. "
+ "If not specified, it will be autoset for the current machine.",
)
auto_scheduler_group.add_argument(
"--max-local-memory-per-block",
type=int,
- default=0,
- help="the max local memory per block in bytes",
+ help="the max local memory per block in bytes. "
+ "If not specified, it will be autoset for the current machine.",
)
auto_scheduler_group.add_argument(
"--max-threads-per-block",
type=int,
- default=0,
- help="the max number of threads per block",
+ help="the max number of threads per block. "
+ "If not specified, it will be autoset for the current machine.",
)
auto_scheduler_group.add_argument(
"--max-vthread-extent",
type=int,
- default=0,
- help="the max vthread extent",
+ help="the max vthread extent. "
+ "If not specified, it will be autoset for the current machine.",
)
auto_scheduler_group.add_argument(
"--warp-size",
type=int,
- default=0,
- help="the thread numbers of a warp",
+ help="the thread numbers of a warp. "
+ "If not specified, it will be autoset for the current machine.",
)
auto_scheduler_group.add_argument(
"--include-simple-tasks",
@@ -216,7 +222,6 @@ def add_tune_parser(subparsers):
help="specify non-generic shapes for model to run, format is "
'"input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]"',
type=common.parse_shape_string,
- default=None,
)
@@ -228,8 +233,22 @@ def drive_tune(args):
args: argparse.Namespace
Arguments from command line parser.
"""
- # extra arguments validation before importing the model, so that obvious errors
- # are pointed in advance.
+ tvmc_model = frontends.load_model(args.FILE, args.model_format, shape_dict=args.input_shapes)
+
+ # Specify hardware parameters, although they'll only be used if autoscheduling.
+ hardware_params = auto_scheduler.HardwareParams(
+ num_cores=args.num_cores,
+ vector_unit_bytes=args.vector_unit_bytes,
+ cache_line_bytes=args.cache_line_bytes,
+ max_shared_memory_per_block=args.max_shared_memory_per_block,
+ max_local_memory_per_block=args.max_local_memory_per_block,
+ max_threads_per_block=args.max_threads_per_block,
+ max_vthread_extent=args.max_vthread_extent,
+ warp_size=args.warp_size,
+ target=args.target,
+ target_host=args.target_host,
+ )
+
if args.rpc_tracker:
parsed_url = urlparse("//%s" % args.rpc_tracker)
rpc_hostname = parsed_url.hostname
@@ -241,11 +260,127 @@ def drive_tune(args):
raise common.TVMCException(
"need to provide an RPC tracker key (--rpc-key) for remote
tuning"
)
+ else:
+ rpc_hostname = None
+ rpc_port = None
+
+ tune_model(
+ tvmc_model,
+ args.target,
+ tuning_records=args.output,
+ prior_records=args.tuning_records,
+ enable_autoscheduler=args.enable_autoscheduler,
+ rpc_key=args.rpc_key,
+ hostname=rpc_hostname,
+ port=rpc_port,
+ trials=args.trials,
+ target_host=args.target_host,
+ tuner=args.tuner,
+ min_repeat_ms=args.min_repeat_ms,
+ early_stopping=args.early_stopping,
+ desired_layout=args.desired_layout,
+ timeout=args.timeout,
+ repeat=args.repeat,
+ number=args.number,
+ parallel=args.parallel,
+ hardware_params=hardware_params,
+ include_simple_tasks=args.include_simple_tasks,
+ log_estimated_latency=args.log_estimated_latency,
+ )
+
+
+def tune_model(
+ tvmc_model: TVMCModel,
+ target: str,
+ tuning_records: Optional[str] = None,
+ prior_records: Optional[str] = None,
+ enable_autoscheduler: bool = False,
+ rpc_key: Optional[str] = None,
+ hostname: Optional[str] = None,
+ port: Optional[Union[int, str]] = 9090,
+ trials: int = 10000,
+ target_host: Optional[str] = None,
+ tuner: str = "xgb",
+ min_repeat_ms: Optional[int] = None,
+ early_stopping: Optional[int] = None,
+ desired_layout: Optional[str] = None,
+ timeout: int = 10,
+ repeat: int = 1,
+ number: int = 10,
+ parallel: int = 4,
+ hardware_params: Optional[HardwareParams] = None,
+ include_simple_tasks: bool = False,
+ log_estimated_latency: bool = False,
+):
+ """Use tuning to automatically optimize the functions in a model.
+
+ Parameters
+ ----------
+ tvmc_model : TVMCModel
+ The model to be optimized.
+ target : str
+ Compilation target as plain string, inline JSON or path to a JSON file.
+ tuning_records: str, optional
+ The path to a file that tuning results will be saved to. If not specified,
+ a temporary file will be used.
+ prior_records: str, optional
+ A path to previous tuning results that will be used to hot-start the tuning
+ cost model if provided.
+ enable_autoscheduler : bool, optional
+ When true, use autoscheduling rather than autotvm. This should produce
+ faster kernels for compatible model-target pairs.
+ rpc_key : str, optional
+ The RPC tracker key of the target device. Required when rpc_tracker is provided.
+ hostname : str, optional
+ The IP address of an RPC tracker, used when benchmarking remotely.
+ port : int or str, optional
+ The port of the RPC tracker to connect to. Defaults to 9090.
+ trials : int, optional
+ The number of schedules to try out for the entire model. Note that the default
+ value is chosen as a decent average for most models, but larger models may need
+ more trials to reach a good result while smaller models will converge with fewer
+ trials.
+ tuner : str, optional
+ The type of tuner to use when tuning with autotvm. Can be one of
+ "ga", "gridsearch", "random", "xgb", "xgb_knob", and "xgb-rank".
+ min_repeat_ms : int, optional
+ Minimum time to run each trial. Defaults to 0 on x86 and 1000 on other targets.
+ early_stopping : int, optional
+ When specified, stop tuning after this number of trials if results aren't improving.
+ desired_layout : str, optional
+ Can be one of "NCHW" or "NHWC". When specified, compatible operations in the graph
+ will have their layout set to this format. Tasks will then be tuned using this
+ specified layout.
+ timeout : int, optional
+ If a kernel trial lasts longer than this duration in seconds, it will be
+ considered a failure.
+ repeat : int, optional
+ How many times each measurement should be repeated.
+ number : int, optional
+ The number of runs a single repeat is made of.
+ parallel : int, optional
+ The maximum number of parallel devices to use when tuning.
+ hardware_params : auto_scheduler.HardwareParams, optional
+ When using the autoscheduler, this object defines the configuration of the target hardware.
+ include_simple_tasks : bool, optional
+ Whether to extract simple operations or only computationally intensive ones when using
+ the autoscheduler.
+ log_estimated_latency : bool, optional
+ If using the autoscheduler, write the estimated latency at each step of tuning to file.
- target, extra_targets = common.target_from_cli(args.target)
- target_host = args.target_host
+ Returns
+ -------
+ tuning_records : str
+ The path to the produced tuning log file.
+ """
+ target, extra_targets = common.target_from_cli(target)
target, target_host = Target.check_and_update_host_consist(target, target_host)
- mod, params = frontends.load_model(args.FILE, args.model_format, shape_dict=args.input_shapes)
+ # TODO(jwfromm) Remove this deepcopy once AlterOpLayout bug that mutates source
+ # model is fixed. For now, creating a clone avoids the issue.
+ mod = deepcopy(tvmc_model.mod)
+ params = tvmc_model.params
+ if tuning_records is None:
+ tuning_records = tvmc_model.default_tuning_records_path()
for codegen_from_cli in extra_targets:
codegen = composite_target.get_codegen_by_target(codegen_from_cli["name"])
@@ -255,97 +390,113 @@ def drive_tune(args):
# min_repeat_ms should be:
# a. the value provided by the user, if any, or
# b. 0ms in case target is "cpu"; otherwise 1000ms
- if args.min_repeat_ms is not None:
- min_repeat_ms = args.min_repeat_ms
- else:
+ if min_repeat_ms is None:
min_repeat_ms = 0 if target.keys[0] == "cpu" else 1000
- logger.debug("Default --min-repeat-ms for this target is %s",
min_repeat_ms)
+ logger.info("Default --min-repeat-ms for this target is %s",
min_repeat_ms)
- if args.rpc_tracker:
- runner_ctor = auto_scheduler.RPCRunner if args.enable_autoscheduler else autotvm.RPCRunner
+ if rpc_key:
+ if hostname is None or port is None:
+ raise common.TVMCException(
+ "You must provide a hostname and port to connect to a remote
RPC device."
+ )
+ if isinstance(port, str):
+ port = int(port)
+
+ logger.info("Tuning will be performed on device %s at %s:%d.",
rpc_key, hostname, port)
+
+ runner_ctor = auto_scheduler.RPCRunner if enable_autoscheduler else autotvm.RPCRunner
runner = runner_ctor(
- key=args.rpc_key,
- host=rpc_hostname,
- port=rpc_port,
- number=args.number,
- repeat=args.repeat,
- n_parallel=args.parallel,
- timeout=args.timeout,
+ key=rpc_key,
+ host=hostname,
+ port=port,
+ number=number,
+ repeat=repeat,
+ n_parallel=parallel,
+ timeout=timeout,
min_repeat_ms=min_repeat_ms,
)
else:
- logger.info("starting localhost tuning")
+ logger.info("Starting localhost tuning.")
runner_ctor = (
- auto_scheduler.LocalRunner if args.enable_autoscheduler else autotvm.LocalRunner
+ auto_scheduler.LocalRPCMeasureContext if enable_autoscheduler else autotvm.LocalRunner
)
- runner = runner_ctor(
- number=args.number,
- repeat=args.repeat,
- timeout=args.timeout,
+ local_server = runner_ctor(
+ number=number,
+ repeat=repeat,
+ timeout=timeout,
min_repeat_ms=min_repeat_ms,
)
- if args.enable_autoscheduler:
- # Specify hardware parameters
- hardware_params = auto_scheduler.HardwareParams(
- args.num_cores,
- args.vector_unit_bytes,
- args.cache_line_bytes,
- args.max_shared_memory_per_block,
- args.max_local_memory_per_block,
- args.max_threads_per_block,
- args.max_vthread_extent,
- args.warp_size,
- )
+ # For autoscheduling on some devices, we need to maintain a LocalRPCMeasureContext object.
+ if enable_autoscheduler:
+ runner = local_server.runner
+ else:
+ runner = local_server
+
+ if enable_autoscheduler:
+
tasks, weights = autoscheduler_get_tuning_tasks(
mod=mod,
params=params,
target=target,
- alter_layout=args.desired_layout,
+ alter_layout=desired_layout,
hardware_params=hardware_params,
- include_simple_tasks=args.include_simple_tasks,
+ include_simple_tasks=include_simple_tasks,
)
# Create the autoscheduler tuning options
tuning_options = auto_scheduler.TuningOptions(
- num_measure_trials=args.trials,
- measure_callbacks=[auto_scheduler.RecordToFile(args.output)],
+ num_measure_trials=trials,
+ measure_callbacks=[auto_scheduler.RecordToFile(tuning_records)],
runner=runner,
- early_stopping=args.early_stopping,
+ early_stopping=early_stopping,
)
+ logger.info("Autoscheduling with configuration: %s", tuning_options)
+
# Schedule the tasks (i.e., produce a schedule for each task)
- schedule_tasks(
- tasks, weights, tuning_options, args.tuning_records, args.log_estimated_latency
- )
+ schedule_tasks(tasks, weights, tuning_options, prior_records, log_estimated_latency)
else:
tasks = autotvm_get_tuning_tasks(
mod=mod,
params=params,
target=target,
- alter_layout=args.desired_layout,
+ alter_layout=desired_layout,
)
- tuning_option = {
- "tuner": args.tuner,
- "trials": args.trials,
- "early_stopping": args.early_stopping,
+ # In autotvm, trials is specified per task. We can convert the per-model input
+ # provided to per-task trials by dividing by the number of tasks.
+ trials = int(trials / len(tasks))
+ logger.info("Autotuning with %d trials per task.", trials)
+
+ tuning_options = {
+ "tuner": tuner,
+ "trials": trials,
+ "early_stopping": early_stopping,
"measure_option": autotvm.measure_option(
builder=autotvm.LocalBuilder(build_func="default"), runner=runner
),
- "tuning_records": args.tuning_records,
+ "tuning_records": prior_records,
}
- logger.debug(" tuning options: %s", tuning_option)
+ logger.info("Autotuning with configuration: %s", tuning_options)
- tune_tasks(tasks, args.output, **tuning_option)
+ tune_tasks(tasks, tuning_records, **tuning_options)
+ return tuning_records
-def autotvm_get_tuning_tasks(mod, params, target, target_host=None, alter_layout=None):
+
+def autotvm_get_tuning_tasks(
+ mod: tvm.IRModule,
+ params: Dict[str, tvm.nd.NDArray],
+ target: str,
+ target_host: Optional[str] = None,
+ alter_layout: Optional[str] = None,
+):
"""Get the autotvm tuning tasks for a given relay module.
Parameters
----------
- mod : tvm.relay.Module
+ mod : tvm.IRModule
The relay module from which to extract tuning tasks.
params : dict
The params for the relay module.
@@ -378,19 +529,19 @@ def autotvm_get_tuning_tasks(mod, params, target, target_host=None, alter_layout
def autoscheduler_get_tuning_tasks(
- mod,
- params,
- target,
- target_host=None,
- alter_layout=None,
- hardware_params=None,
- include_simple_tasks=False,
+ mod: tvm.IRModule,
+ params: Dict[str, tvm.nd.NDArray],
+ target: str,
+ target_host: Optional[str] = None,
+ alter_layout: Optional[str] = None,
+ hardware_params: Optional[HardwareParams] = None,
+ include_simple_tasks: bool = False,
):
"""Get the autoscheduler tuning tasks for a given relay module.
Parameters
----------
- mod : tvm.relay.Module
+ mod : tvm.IRModule
The relay module from which to extract tuning tasks.
params : dict
The params for the relay module.
@@ -430,7 +581,11 @@ def autoscheduler_get_tuning_tasks(
def schedule_tasks(
- tasks, task_weights, tuning_options, tuning_records=None, log_estimated_latency=False
+ tasks: List[auto_scheduler.SearchTask],
+ task_weights: List[float],
+ tuning_options: auto_scheduler.TuningOptions,
+ prior_records: Optional[str] = None,
+ log_estimated_latency: bool = False,
):
"""Generate the schedules for the different tasks (i.e., subgraphs)
contained in the module.
Store the schedules in a json file that will be used later by the compiler.
@@ -441,10 +596,12 @@ def schedule_tasks(
A list of auto_scheduler.SearchTask to tune.
task_weights : list
The weight (i.e. the number of appearance) of extracted tasks
- tuning_options: dict
+ tuning_options: auto_scheduler.TuningOptions
The options of tuning
- tuning_records : str, optional
+ prior_records : str, optional
The json file used to preload the autoscheduler
+ log_estimated_latency : bool, optional
+ If true, writes the estimated runtime of the model during each step of tuning to file.
"""
if not log_estimated_latency:
callbacks = [auto_scheduler.task_scheduler.PrintTableInfo()]
@@ -456,7 +613,7 @@ def schedule_tasks(
# Create the scheduler
tuner = auto_scheduler.TaskScheduler(
- tasks, task_weights, load_log_file=tuning_records, callbacks=callbacks
+ tasks, task_weights, load_log_file=prior_records, callbacks=callbacks
)
# Tune the tasks
@@ -464,13 +621,13 @@ def schedule_tasks(
def tune_tasks(
- tasks,
- log_file,
- measure_option,
- tuner,
- trials,
- early_stopping=None,
- tuning_records=None,
+ tasks: List[autotvm.task.Task],
+ log_file: str,
+ measure_option: autotvm.measure_option,
+ tuner: str,
+ trials: int,
+ early_stopping: Optional[int] = None,
+ tuning_records: Optional[str] = None,
):
"""Tune a list of tasks and output the history to a log file.
diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py
index 77ba1cb..34f59aa 100644
--- a/python/tvm/driver/tvmc/common.py
+++ b/python/tvm/driver/tvmc/common.py
@@ -59,6 +59,7 @@ def convert_graph_layout(mod, desired_layout):
# conv2d as heavily-sensitive operators.
desired_layouts = {
"nn.conv2d": [desired_layout, "default"],
+ "nn.conv2d_transpose": [desired_layout, "default"],
"qnn.conv2d": [desired_layout, "default"],
}
@@ -99,8 +100,8 @@ def validate_targets(parse_targets):
if len(tvm_targets) > 1:
verbose_tvm_targets = ", ".join(tvm_targets)
raise TVMCException(
- f"Only one of the following targets can be used at a time. "
- "Found: {verbose_tvm_targets}."
+ "Only one of the following targets can be used at a time. "
+ f"Found: {verbose_tvm_targets}."
)
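The new desired_layouts entry means transposed convolutions are also
rewritten by the layout conversion pass. A sketch of the relay transform this
feeds into (the NHWC choice is illustrative):

.. code-block:: python

    from tvm import relay

    desired_layouts = {
        "nn.conv2d": ["NHWC", "default"],
        "nn.conv2d_transpose": ["NHWC", "default"],
        "qnn.conv2d": ["NHWC", "default"],
    }
    # Applied inside common.convert_graph_layout via a pass sequence.
    convert = relay.transform.ConvertLayout(desired_layouts)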
diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py
index 6884c30..3f1d04a 100644
--- a/python/tvm/driver/tvmc/compiler.py
+++ b/python/tvm/driver/tvmc/compiler.py
@@ -19,17 +19,16 @@ Provides support to compile networks both AOT and JIT.
"""
import logging
import os.path
-import tarfile
+from typing import Optional, Dict, List, Union, Callable
from pathlib import Path
import tvm
from tvm import autotvm, auto_scheduler
-from tvm import relay, runtime
-from tvm.contrib import cc
-from tvm.contrib import utils
+from tvm import relay
from tvm.target import Target
from . import common, composite_target, frontends
+from .model import TVMCModel, TVMCPackage
from .main import register_parser
@@ -96,7 +95,7 @@ def add_compile_parser(subparsers):
default=None,
)
parser.add_argument(
- "--disable-pass",
+ "--disabled-pass",
help="disable specific passes, comma-separated list of pass names",
type=common.parse_pass_list_str,
default="",
@@ -117,35 +116,36 @@ def drive_compile(args):
Zero if successfully completed
"""
- mod, params = frontends.load_model(args.FILE, args.model_format, args.input_shapes)
+ tvmc_model = frontends.load_model(args.FILE, args.model_format, args.input_shapes)
- graph, lib, params, dumps = compile_model(
- mod,
- params,
+ dump_code = [x.strip() for x in args.dump_code.split(",")] if args.dump_code else None
+
+ compile_model(
+ tvmc_model,
args.target,
- args.dump_code,
- None,
- args.tuning_records,
- args.desired_layout,
- args.disable_pass,
+ tuning_records=args.tuning_records,
+ package_path=args.output,
+ cross=args.cross_compiler,
+ dump_code=dump_code,
+ target_host=None,
+ desired_layout=args.desired_layout,
+ disabled_pass=args.disabled_pass,
)
- if dumps:
- save_dumps(args.output, dumps)
-
- save_module(args.output, graph, lib, params, args.cross_compiler)
return 0
def compile_model(
- mod,
- params,
- target,
- dump_code=None,
- target_host=None,
- tuning_records=None,
- alter_layout=None,
- disabled_pass=None,
+ tvmc_model: TVMCModel,
+ target: str,
+ tuning_records: Optional[str] = None,
+ package_path: Optional[str] = None,
+ cross: Optional[Union[str, Callable]] = None,
+ export_format: str = "so",
+ dump_code: Optional[List[str]] = None,
+ target_host: Optional[str] = None,
+ desired_layout: Optional[str] = None,
+ disabled_pass: Optional[str] = None,
):
"""Compile a model from a supported framework into a TVM module.
@@ -155,23 +155,29 @@ def compile_model(
Parameters
----------
- mod: IRModule
- The relay module to be compiled.
- params: dict
- A dictionary containing the module's parameters.
+ tvmc_model : TVMCModel
+ The model object that should be compiled.
target : str
The target for which to compile. Can be a plain string or
a path.
+ tuning_records : str, optional
+ A path to tuning records produced using tvmc.tune. When provided,
+ compilation will use more optimized kernels leading to better results.
+ package_path : str, optional
+ The path to export the compiled model to. If not provided it will
+ be saved in a temporary directory.
+ cross : str or callable object, optional
+ Function that performs the actual compilation.
+ export_format : str
+ What format to use when saving the function library. Must be one of "so" or "tar".
+ When compiling for a remote device without a cross compiler, "tar" will likely work better.
dump_code : list, optional
Dump the generated code for the specified source types, on
the requested target.
target_host : str, optional
The target of the host machine if host-side code
needs to be generated.
- tuning_records: str, optional
- Path to the file produced by the tuning to be used during
- compilation.
- alter_layout: str, optional
+ desired_layout: str, optional
The layout to convert the graph to. Note, the convert layout
pass doesn't currently guarantee the whole of the graph will
be converted to the chosen layout.
@@ -182,24 +188,18 @@ def compile_model(
Returns
-------
- graph : str
- A JSON-serialized TVM execution graph.
- lib : tvm.module.Module
- A TVM module containing the compiled functions.
- params : dict
- The parameters (weights) for the TVM module.
- dumps : dict
- Dictionary containing the dumps specified.
+ compiled_model : TVMCPackage
+ The compiled TVMCModel ready to be run.
"""
- dump_code = [x.strip() for x in dump_code.split(",")] if dump_code else None
+ mod, params = tvmc_model.mod, tvmc_model.params
+
config = {}
- if alter_layout:
- mod = common.convert_graph_layout(mod, alter_layout)
+ if desired_layout:
+ mod = common.convert_graph_layout(mod, desired_layout)
tvm_target, extra_targets = common.target_from_cli(target)
- target_host = tvm_target if not target_host else target_host
tvm_target, target_host = Target.check_and_update_host_consist(tvm_target, target_host)
for codegen_from_cli in extra_targets:
@@ -225,21 +225,24 @@ def compile_model(
opt_level=3, config=config, disabled_pass=disabled_pass
):
logger.debug("building relay graph with autoscheduler")
- graph_module = relay.build(mod, target=target, params=params)
+ graph_module = relay.build(mod, target=tvm_target, params=params)
else:
with autotvm.apply_history_best(tuning_records):
with tvm.transform.PassContext(
opt_level=3, config=config, disabled_pass=disabled_pass
):
logger.debug("building relay graph with tuning records")
- graph_module = relay.build(mod, tvm_target, params=params)
+ graph_module = relay.build(mod, target=tvm_target, params=params)
else:
with tvm.transform.PassContext(opt_level=3, config=config, disabled_pass=disabled_pass):
logger.debug("building relay graph (no tuning records provided)")
- graph_module = relay.build(mod, tvm_target, params=params)
+ graph_module = relay.build(mod, target=tvm_target, params=params)
# Generate output dump files with sources
- dump_code = dump_code or []
+ if dump_code is None:
+ dump_code = []
+ if not isinstance(dump_code, list):
+ dump_code = [dump_code]
dumps = {}
for source_type in dump_code:
lib = graph_module.get_lib()
@@ -248,59 +251,17 @@ def compile_model(
source = str(mod) if source_type == "relay" else lib.get_source(source_type)
dumps[source_type] = source
- # TODO we need to update this return to use the updated graph module APIs
- # as these getter functions will be deprecated in the next release (@leandron)
- return graph_module.get_json(), graph_module.get_lib(), graph_module.get_params(), dumps
-
+ # Create a new tvmc model package object from the graph definition.
+ package_path = tvmc_model.export_package(graph_module, package_path, cross, export_format)
-def save_module(module_path, graph, lib, params, cross=None):
- """
- Create a tarball containing the generated TVM graph,
- exported library and parameters
-
- Parameters
- ----------
- module_path : str
- path to the target tar.gz file to be created,
- including the file name
- graph : str
- A JSON-serialized TVM execution graph.
- lib : tvm.module.Module
- A TVM module containing the compiled functions.
- params : dict
- The parameters (weights) for the TVM module.
- cross : str or callable object, optional
- Function that performs the actual compilation
-
- """
- lib_name = "mod.so"
- graph_name = "mod.json"
- param_name = "mod.params"
- temp = utils.tempdir()
- path_lib = temp.relpath(lib_name)
- if not cross:
- logger.debug("exporting library to %s", path_lib)
- lib.export_library(path_lib)
- else:
- logger.debug("exporting library to %s , using cross compiler %s",
path_lib, cross)
- lib.export_library(path_lib, cc.cross_compiler(cross))
-
- with open(temp.relpath(graph_name), "w") as graph_file:
- logger.debug("writing graph to file to %s", graph_file.name)
- graph_file.write(graph)
-
- with open(temp.relpath(param_name), "wb") as params_file:
- logger.debug("writing params to file to %s", params_file.name)
- params_file.write(runtime.save_param_dict(params))
+ # Write dumps to file.
+ if dumps:
+ save_dumps(package_path, dumps)
- logger.debug("saving module as tar file to %s", module_path)
- with tarfile.open(module_path, "w") as tar:
- tar.add(path_lib, lib_name)
- tar.add(temp.relpath(graph_name), graph_name)
- tar.add(temp.relpath(param_name), param_name)
+ return TVMCPackage(package_path)
-def save_dumps(module_name, dumps, dump_root="."):
+def save_dumps(module_name: str, dumps: Dict[str, str], dump_root: str = "."):
"""
Serialize dump files to the disk.
@@ -313,7 +274,6 @@ def save_dumps(module_name, dumps, dump_root="."):
The output contents to be saved into the files
dump_root : str, optional
Path in which dump files will be created
-
"""
for dump_format in dumps:
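After this refactor, compile_model returns a TVMCPackage rather than raw
build artifacts. A sketch of the new call, assuming a loaded TVMCModel
(paths are placeholders):

.. code-block:: python

    from tvm.driver import tvmc
    from tvm.driver.tvmc.compiler import compile_model

    model = tvmc.load("my_model.onnx")
    # "tar" export is handy when no cross compiler is available.
    package = compile_model(
        model, target="llvm", package_path="pkg.tar", export_format="tar"
    )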
diff --git a/python/tvm/driver/tvmc/frontends.py b/python/tvm/driver/tvmc/frontends.py
index 0488223..89ca1b8 100644
--- a/python/tvm/driver/tvmc/frontends.py
+++ b/python/tvm/driver/tvmc/frontends.py
@@ -25,12 +25,14 @@ import os
import sys
from abc import ABC
from abc import abstractmethod
+from typing import Optional, List, Dict
from pathlib import Path
import numpy as np
from tvm import relay
from tvm.driver.tvmc.common import TVMCException
+from tvm.driver.tvmc.model import TVMCModel
# pylint: disable=invalid-name
@@ -284,7 +286,7 @@ def get_frontend_names():
return [frontend.name() for frontend in ALL_FRONTENDS]
-def get_frontend_by_name(name):
+def get_frontend_by_name(name: str):
"""
This function will try to get a frontend instance, based
on the name provided.
@@ -311,7 +313,7 @@ def get_frontend_by_name(name):
)
-def guess_frontend(path):
+def guess_frontend(path: str):
"""
This function will try to imply which framework is being used,
based on the extension of the file provided in the path parameter.
@@ -340,7 +342,12 @@ def guess_frontend(path):
raise TVMCException("failed to infer the model format. Please specify
--model-format")
-def load_model(path, model_format=None, shape_dict=None, **kwargs):
+def load_model(
+ path: str,
+ model_format: Optional[str] = None,
+ shape_dict: Optional[Dict[str, List[int]]] = None,
+ **kwargs,
+):
"""Load a model from a supported framework and convert it
into an equivalent relay representation.
@@ -356,10 +363,8 @@ def load_model(path, model_format=None, shape_dict=None, **kwargs):
Returns
-------
- mod : tvm.relay.Module
- The produced relay module.
- params : dict
- The parameters (weights) for the relay module.
+ tvmc_model : TVMCModel
+ The produced model package.
"""
@@ -370,4 +375,4 @@ def load_model(path, model_format=None, shape_dict=None, **kwargs):
mod, params = frontend.load(path, shape_dict, **kwargs)
- return mod, params
+ return TVMCModel(mod, params)
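load_model now wraps the imported module and parameters in a TVMCModel. A
brief sketch with an explicit shape override (the file and input names are
assumptions):

.. code-block:: python

    from tvm.driver import tvmc

    model = tvmc.load("my_model.onnx", shape_dict={"input": [1, 3, 224, 224]})
    print(type(model))  # tvm.driver.tvmc.model.TVMCModel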
diff --git a/python/tvm/driver/tvmc/model.py b/python/tvm/driver/tvmc/model.py
new file mode 100644
index 0000000..a26a47c
--- /dev/null
+++ b/python/tvm/driver/tvmc/model.py
@@ -0,0 +1,364 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+This file contains the definition of a set of classes that wrap the outputs
+of TVMC functions to create a simpler and more intuitive API.
+
+There is one class for each required stage of a TVM workflow.
+The TVMCModel represents the result of importing a model into TVM; it
+contains the precompiled graph definition and the parameters that define
+what the model does.
+
+Compiling a TVMCModel produces a TVMCPackage, which contains the generated
+artifacts that allow the model to be run on the target hardware.
+
+Running a TVMCPackage produces a TVMCResult, which contains the outputs of
+the model and the measured runtime.
+
+Examples
+--------
+The following code shows a full lifecycle for a model using tvmc: first the
+model is imported from an exterior framework, in this case ONNX; then it
+is tuned to find the best schedules on CPU, compiled into a TVMCPackage,
+and finally run.
+
+.. code-block:: python
+ tvmc_model = tvmc.load("my_model.onnx")
+ tuning_records = tvmc.tune(tvmc_model, target="llvm")
+ tvmc_package = tvmc.compile(tvmc_model, target="llvm", tuning_records=tuning_records)
+ result = tvmc.run(tvmc_package, device="cpu")
+ print(result)
+"""
+import os
+import tarfile
+from typing import Optional, Union, List, Dict, Callable, TextIO
+import numpy as np
+
+import tvm
+import tvm.contrib.cc
+from tvm import relay
+from tvm.contrib import utils
+from tvm.relay.backend.graph_executor_factory import GraphExecutorFactoryModule
+
+from .common import TVMCException
+
+
+class TVMCModel(object):
+ """Initialize a TVMC model from a relay model definition or a saved file.
+
+ Parameters
+ ----------
+ mod : tvm.IRModule, optional
+ The relay module corresponding to this model.
+ params : dict, optional
+ A parameter dictionary for the model.
+ model_path: str, optional
+ An alternative way to load a TVMCModel: the path to a previously
+ saved model.
+ """
+
+ def __init__(
+ self,
+ mod: Optional[tvm.IRModule] = None,
+ params: Optional[Dict[str, tvm.nd.NDArray]] = None,
+ model_path: Optional[str] = None,
+ ):
+ if (mod is None or params is None) and (model_path is None):
+ raise TVMCException(
+ "Either mod and params must be provided "
+ "or a path to a previously saved TVMCModel"
+ )
+ self._tmp_dir = utils.tempdir()
+ if model_path is not None:
+ self.load(model_path)
+ else:
+ self.mod = mod
+ self.params = params if params else {}
+
+ def save(self, model_path: str):
+ """Save the TVMCModel to disk.
+
+ Note that this saves the graph representation, the parameters, the tuning
+ records if applicable, and any compiled package found at the default
+ package path.
+
+ Parameters
+ ----------
+ model_path : str
+ A full path to save this TVMCModel to, including the output file name.
+ The file will be saved as a tar file, so using a ".tar" extension
+ is advised.
+ """
+ temp = self._tmp_dir
+
+ # Save relay graph
+ relay_name = "model.json"
+ relay_path = temp.relpath(relay_name)
+ with open(relay_path, "w") as relay_file:
+ relay_file.write(tvm.ir.save_json(self.mod))
+
+ # Save params
+ params_name = "model.params"
+ params_path = temp.relpath(params_name)
+ with open(params_path, "wb") as params_file:
+ params_file.write(relay.save_param_dict(self.params))
+
+ # Create a tar file.
+ with tarfile.open(model_path, "w") as tar:
+ tar.add(relay_path, relay_name)
+ tar.add(params_path, params_name)
+ # If default tuning records exist, save them as well.
+ if os.path.exists(self.default_tuning_records_path()):
+ tar.add(self.default_tuning_records_path(), "tuning_records")
+ # Also save the compiled package if it can be found.
+ if os.path.exists(self.default_package_path()):
+ tar.add(self.default_package_path(), "model_package.tar")
+
+ def load(self, model_path: str):
+ """Load a TVMCModel from disk.
+
+ Parameters
+ ----------
+ model_path : str
+ A path to load the TVMCModel from.
+ """
+ temp = self._tmp_dir
+ t = tarfile.open(model_path)
+ t.extractall(temp.relpath("."))
+
+ # Load relay IR.
+ relay_path = temp.relpath("model.json")
+ with open(relay_path, "r") as relay_file:
+ self.mod = tvm.ir.load_json(relay_file.read())
+
+ # Load parameter dictionary.
+ params_path = temp.relpath("model.params")
+ with open(params_path, "rb") as params_file:
+ self.params = relay.load_param_dict(params_file.read())
+
+ def default_tuning_records_path(self):
+ """Get a full path for storing tuning records in this model's
temporary direcotry
+
+ Note that when this path is used, the tuning records will be saved and loaded
+ when calling `save` and `load`.
+
+ Returns
+ -------
+ records_path: str
+ A path to the default location for tuning records.
+ """
+ return self._tmp_dir.relpath("tuning_records")
+
+ def default_package_path(self):
+ """Get a full path for storing a compiled package in this model's
temporary direcotry
+
+ Note that when this path is used, the package will be saved and loaded
+
+ Returns
+ -------
+ package_path: str
+ A path to the default location for the compiled package.
+ """
+ return self._tmp_dir.relpath("model_package.tar")
+
+ def export_package(
+ self,
+ executor_factory: GraphExecutorFactoryModule,
+ package_path: Optional[str] = None,
+ cross: Optional[Union[str, Callable]] = None,
+ lib_format: str = "so",
+ ):
+ """Save this TVMCModel to file.
+ Parameters
+ ----------
+ executor_factory : GraphExecutorFactoryModule
+ The factory containing the compiled artifacts needed to run this model.
+ package_path : str, optional
+ Where the model should be saved. Note that it will be packaged as a .tar file.
+ If not provided, the package will be saved to a generically named file in tmp.
+ cross : str or callable object, optional
+ Function that performs the actual compilation.
+ lib_format : str
+ How to export the module's function library. Must be one of "so" or "tar".
+
+ Returns
+ -------
+ package_path : str
+ The path that the package was saved to.
+ """
+ if lib_format not in ["so", "tar"]:
+ raise TVMCException("Only .so and .tar export formats are
supported.")
+ lib_name = "mod." + lib_format
+ graph_name = "mod.json"
+ param_name = "mod.params"
+
+ temp = self._tmp_dir
+ if package_path is None:
+ package_path = self.default_package_path()
+ path_lib = temp.relpath(lib_name)
+
+ if not cross:
+ executor_factory.get_lib().export_library(path_lib)
+ else:
+ executor_factory.get_lib().export_library(
+ path_lib, tvm.contrib.cc.cross_compiler(cross)
+ )
+ self.lib_path = path_lib
+
+ with open(temp.relpath(graph_name), "w") as graph_file:
+ graph_file.write(executor_factory.get_json())
+
+ with open(temp.relpath(param_name), "wb") as params_file:
+ params_file.write(relay.save_param_dict(executor_factory.get_params()))
+
+ # Package up all the temp files into a tar file.
+ with tarfile.open(package_path, "w") as tar:
+ tar.add(path_lib, lib_name)
+ tar.add(temp.relpath(graph_name), graph_name)
+ tar.add(temp.relpath(param_name), param_name)
+
+ return package_path
+
+ def summary(self, file: TextIO = None):
+ """Print the IR corressponding to this model.
+
+ Arguments
+ ---------
+ file: Writable, optional
+ If specified, the summary will be written to this file.
+ """
+ print(self.mod, file=file)
+
+
+class TVMCPackage(object):
+ """Load a saved TVMCPackage from disk.
+
+ Parameters
+ ----------
+ package_path : str
+ The path to the saved TVMCPackage that will be loaded.
+ """
+
+ def __init__(self, package_path: str):
+ self._tmp_dir = utils.tempdir()
+ self.package_path = package_path
+ self.import_package(self.package_path)
+
+ def import_package(self, package_path: str):
+ """Load a TVMCPackage from a previously exported TVMCModel.
+
+ Parameters
+ ----------
+ package_path : str
+ The path to the saved TVMCPackage.
+ """
+ lib_name_so = "mod.so"
+ lib_name_tar = "mod.tar"
+ graph_name = "mod.json"
+ param_name = "mod.params"
+
+ temp = self._tmp_dir
+ t = tarfile.open(package_path)
+ t.extractall(temp.relpath("."))
+
+ with open(temp.relpath(param_name), "rb") as param_file:
+ self.params = bytearray(param_file.read())
+ self.graph = open(temp.relpath(graph_name)).read()
+ if os.path.exists(temp.relpath(lib_name_so)):
+ self.lib_name = lib_name_so
+ elif os.path.exists(temp.relpath(lib_name_tar)):
+ self.lib_name = lib_name_tar
+ else:
+ raise TVMCException("Couldn't find exported library in the
package.")
+ self.lib_path = temp.relpath(self.lib_name)
+
+
+class TVMCResult(object):
+ """A class that stores the results of tvmc.run and provides helper
utilities."""
+
+ def __init__(self, outputs: Dict[str, np.ndarray], times: List[float]):
+ """Create a convenience wrapper around the output of tvmc.run
+
+ Parameters
+ ----------
+ outputs : dict
+ Outputs dictionary mapping the name of the output to its numpy value.
+ times : list of float
+ The execution times, in seconds, measured by the time evaluator to produce outputs.
+ """
+ self.outputs = outputs
+ self.times = times
+
+ def format_times(self):
+ """Format the mean, max, min and std of the execution times.
+
+ This has the effect of producing a small table that looks like:
+ .. code-block::
+ Execution time summary:
+ mean (ms) max (ms) min (ms) std (ms)
+ 0.14310 0.16161 0.12933 0.01004
+
+ Returns
+ -------
+ str
+ A formatted string containing the statistics.
+ """
+
+ # timestamps
+ mean_ts = np.mean(self.times) * 1000
+ std_ts = np.std(self.times) * 1000
+ max_ts = np.max(self.times) * 1000
+ min_ts = np.min(self.times) * 1000
+
+ header = "Execution time summary:\n{0:^10} {1:^10} {2:^10}
{3:^10}".format(
+ "mean (ms)", "max (ms)", "min (ms)", "std (ms)"
+ )
+ stats = "{0:^10.2f} {1:^10.2f} {2:^10.2f} {3:^10.2f}".format(
+ mean_ts, max_ts, min_ts, std_ts
+ )
+
+ return "%s\n%s\n" % (header, stats)
+
+ def get_output(self, name: str):
+ """A helper function to grab one of the outputs by name.
+
+ Parameters
+ ----------
+ name : str
+ The name of the output to return
+
+ Returns
+ -------
+ output : np.ndarray
+ The output corresponding to name.
+ """
+ return self.outputs[name]
+
+ def save(self, output_path: str):
+ """Save the numpy outputs to disk as a .npz file.
+
+ Parameters
+ ----------
+ output_path : str
+ The path to save the numpy results to.
+ """
+ np.savez(output_path, **self.outputs)
+
+ def __str__(self):
+ stat_table = self.format_times()
+ output_keys = f"Output Names:\n {list(self.outputs.keys())}"
+ return stat_table + "\n" + output_keys
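One consequence of the new class is that a model, along with any tuning
records or compiled package stored at its default paths, can round-trip
through disk. A sketch with placeholder file names:

.. code-block:: python

    from tvm.driver import tvmc
    from tvm.driver.tvmc.model import TVMCModel

    model = tvmc.load("my_model.onnx")
    model.save("saved_model.tar")  # graph + params (+ records/package if present)
    restored = TVMCModel(model_path="saved_model.tar")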
diff --git a/python/tvm/driver/tvmc/result_utils.py b/python/tvm/driver/tvmc/result_utils.py
new file mode 100644
index 0000000..10d3159
--- /dev/null
+++ b/python/tvm/driver/tvmc/result_utils.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+This file contains utility functions for processing the outputs
+of TVMC models. These utilities are likely to be task-specific;
+over time, more will be added to support more machine learning tasks.
+
+Examples
+--------
+The following code shows how one might postprocess
+the output of a classification model.
+
+.. code-block:: python
+ result = tvmc.run(tvmc_package, device="cpu")
+ top_results = result_utils.get_top_results(result, max_results=5)
+"""
+import numpy as np
+from .model import TVMCResult
+
+
+def get_top_results(result: TVMCResult, max_results: int):
+ """Return the top n results from the output tensor.
+
+ This function is primarily for image classification and will
+ not necessarily generalize.
+
+ Parameters
+ ----------
+ result : TVMCResult
+ The output of a TVMCModel
+ max_results : int
+ Number of results to return
+
+ Returns
+ -------
+ top_results : np.array
+ Results array of shape (2, n).
+ The first row is the indices and the second is the values.
+
+ """
+ output = np.copy(result.outputs["output_0"])
+ sorted_labels = output.argsort()[0][-max_results:][::-1]
+ output.sort()
+ sorted_values = output[0][-max_results:][::-1]
+ top_results = np.array([sorted_labels, sorted_values])
+ return top_results
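A short usage note for the relocated helper, assuming `result` is a
TVMCResult from a classification model whose first output is named
"output_0":

.. code-block:: python

    from tvm.driver.tvmc.result_utils import get_top_results

    top5 = get_top_results(result, max_results=5)
    # Row 0 holds the class indices, row 1 the corresponding scores.
    print(top5)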
diff --git a/python/tvm/driver/tvmc/runner.py b/python/tvm/driver/tvmc/runner.py
index e1a21b2..b15a16a 100644
--- a/python/tvm/driver/tvmc/runner.py
+++ b/python/tvm/driver/tvmc/runner.py
@@ -19,20 +19,21 @@ Provides support to run compiled networks both locally and remotely.
"""
import json
import logging
-import os
-import tarfile
-import tempfile
+from typing import Optional, Dict, List, Union
import numpy as np
+import tvm
from tvm import rpc
from tvm.autotvm.measure import request_remote
from tvm.contrib import graph_executor as runtime
from tvm.contrib.debugger import debug_executor
-from tvm.relay import load_param_dict
+from tvm.relay.param_dict import load_param_dict
from . import common
+from .model import TVMCPackage, TVMCResult
from .common import TVMCException
from .main import register_parser
+from .result_utils import get_top_results
# pylint: disable=invalid-name
@@ -82,7 +83,10 @@ def add_run_parser(subparsers):
"making it take longer to be generated.",
)
parser.add_argument(
- "--repeat", metavar="N", type=int, default=1, help="repeat the run n
times. Defaults to '1'"
+ "--repeat", metavar="N", type=int, default=1, help="run the model n
times. Defaults to '1'"
+ )
+ parser.add_argument(
+ "--number", metavar="N", type=int, default=1, help="repeat the run n
times. Defaults to '1'"
)
parser.add_argument(
"--rpc-key",
@@ -112,8 +116,10 @@ def drive_run(args):
except IOError as ex:
raise TVMCException("Error loading inputs file: %s" % ex)
- outputs, times = run_module(
- args.FILE,
+ tvmc_package = TVMCPackage(package_path=args.FILE)
+
+ result = run_module(
+ tvmc_package,
args.device,
hostname=rpc_hostname,
port=rpc_port,
@@ -121,25 +127,26 @@ def drive_run(args):
inputs=inputs,
fill_mode=args.fill_mode,
repeat=args.repeat,
+ number=args.number,
profile=args.profile,
)
if args.print_time:
- stat_table = format_times(times)
+ stat_table = result.format_times()
# print here is intentional
print(stat_table)
if args.print_top:
- top_results = get_top_results(outputs, args.print_top)
+ top_results = get_top_results(result, args.print_top)
# print here is intentional
print(top_results)
if args.outputs:
# Save the outputs
- np.savez(args.outputs, **outputs)
+ result.save(args.outputs)
-def get_input_info(graph_str, params):
+def get_input_info(graph_str: str, params: Dict[str, tvm.nd.NDArray]):
"""Return the 'shape' and 'dtype' dictionaries for the input
tensors of a compiled module.
@@ -155,8 +162,8 @@ def get_input_info(graph_str, params):
----------
graph_str : str
JSON graph of the module serialized as a string.
- params : bytearray
- Params serialized as a bytearray.
+ params : dict
+ Parameter dictionary mapping name to value.
Returns
-------
@@ -179,14 +186,14 @@ def get_input_info(graph_str, params):
shape_dict[name] = graph["attrs"]["shape"][1][node_id]
dtype_dict[name] = graph["attrs"]["dltype"][1][node_id]
- logger.debug("collecting graph input shape and type:")
- logger.debug("graph input shape: %s", shape_dict)
- logger.debug("graph input type: %s", dtype_dict)
+ logger.debug("Collecting graph input shape and type:")
+ logger.debug("Graph input shape: %s", shape_dict)
+ logger.debug("Graph input type: %s", dtype_dict)
return shape_dict, dtype_dict
-def generate_tensor_data(shape, dtype, fill_mode):
+def generate_tensor_data(shape: tuple, dtype: str, fill_mode: str):
"""Generate data to produce a tensor of given shape and dtype.
Random data generation depends on the dtype. For int8 types,
@@ -226,7 +233,12 @@ def generate_tensor_data(shape, dtype, fill_mode):
return tensor
-def make_inputs_dict(shape_dict, dtype_dict, inputs=None, fill_mode="random"):
+def make_inputs_dict(
+ shape_dict: Dict[str, List[int]],
+ dtype_dict: Dict[str, str],
+ inputs: Optional[Dict[str, np.ndarray]] = None,
+ fill_mode: str = "random",
+):
"""Make the inputs dictionary for a graph.
Use data from 'inputs' where specified. For input tensors
@@ -289,15 +301,16 @@ def make_inputs_dict(shape_dict, dtype_dict, inputs=None, fill_mode="random"):
def run_module(
- module_file,
- device,
- hostname=None,
- port=9090,
- rpc_key=None,
- inputs=None,
- fill_mode="random",
- repeat=1,
- profile=False,
+ tvmc_package: TVMCPackage,
+ device: str,
+ hostname: Optional[str] = None,
+ port: Union[int, str] = 9090,
+ rpc_key: Optional[str] = None,
+ inputs: Optional[Dict[str, np.ndarray]] = None,
+ fill_mode: str = "random",
+ repeat: int = 10,
+ number: int = 10,
+ profile: bool = False,
):
"""Run a compiled graph executor module locally or remotely with
optional input values.
@@ -307,8 +320,8 @@ def run_module(
Parameters
----------
- module_file : str
- The path to the module file (a .tar file).
+ tvmc_package: TVMCPackage
+ The compiled model package object that will be run.
device: str,
the device (e.g. "cpu" or "gpu") to be targeted by the RPC
session, local or remote).
@@ -320,13 +333,16 @@ def run_module(
The tracker key of the target device. If this is set, it
will be assumed that remote points to a tracker.
inputs : dict, optional
- A dictionary that maps input names to numpy values.
+ A dictionary that maps input names to numpy values. If not provided,
+ inputs will be generated using the fill_mode argument.
fill_mode : str, optional
The fill-mode to use when generating data for input tensors.
Valid options are "zeros", "ones" and "random".
Defaults to "random".
repeat : int, optional
How many times to repeat the run.
+ number : int, optional
+ The number of runs to measure within each repeat.
profile : bool
Whether to profile the run with the debug runtime.
@@ -337,135 +353,73 @@ def run_module(
times : list of str
execution times generated by the time evaluator
"""
-
- with tempfile.TemporaryDirectory() as tmp_dir:
- logger.debug("extracting module file %s", module_file)
- t = tarfile.open(module_file)
- t.extractall(tmp_dir)
- graph = open(os.path.join(tmp_dir, "mod.json")).read()
- params = bytearray(open(os.path.join(tmp_dir, "mod.params"), "rb").read())
-
- if hostname:
- # Remote RPC
- if rpc_key:
- logger.debug("running on remote RPC tracker with key %s",
rpc_key)
- session = request_remote(rpc_key, hostname, port, timeout=1000)
- else:
- logger.debug("running on remote RPC with no key")
- session = rpc.connect(hostname, port)
- else:
- # Local
- logger.debug("running a local session")
- session = rpc.LocalSession()
-
- session.upload(os.path.join(tmp_dir, "mod.so"))
- lib = session.load_module("mod.so")
-
- # TODO expand to other supported devices, as listed in tvm.rpc.client (@leandron)
- logger.debug("device is %s", device)
- if device == "gpu":
- dev = session.gpu()
- elif device == "cl":
- dev = session.cl()
- else:
- assert device == "cpu"
- dev = session.cpu()
-
- if profile:
- logger.debug("creating runtime with profiling enabled")
- module = debug_executor.create(graph, lib, dev, dump_root="./prof")
+ if not isinstance(tvmc_package, TVMCPackage):
+ raise TVMCException(
+ "This model doesn't seem to have been compiled yet. "
+ "Try calling tvmc.compile on the model before running it."
+ )
+
+ if hostname:
+ if isinstance(port, str):
+ port = int(port)
+ # Remote RPC
+ if rpc_key:
+ logger.debug("Running on remote RPC tracker with key %s.", rpc_key)
+ session = request_remote(rpc_key, hostname, port, timeout=1000)
else:
- logger.debug("creating runtime with profiling disabled")
- module = runtime.create(graph, lib, dev)
-
- logger.debug("load params into the runtime module")
- module.load_params(params)
-
- shape_dict, dtype_dict = get_input_info(graph, params)
- inputs_dict = make_inputs_dict(shape_dict, dtype_dict, inputs, fill_mode)
-
- logger.debug("setting inputs to the module")
- module.set_input(**inputs_dict)
-
- # Run must be called explicitly if profiling
- if profile:
- logger.debug("running the module with profiling enabled")
- module.run()
-
- # create the module time evaluator (returns a function)
- timer = module.module.time_evaluator("run", dev, 1, repeat=repeat)
- # call the evaluator function to invoke the module and save execution times
- prof_result = timer()
- # collect a list of execution times from the profiling results
- times = prof_result.results
-
- logger.debug("collecting the output tensors")
- num_outputs = module.get_num_outputs()
- outputs = {}
- for i in range(num_outputs):
- output_name = "output_{}".format(i)
- outputs[output_name] = module.get_output(i).asnumpy()
-
- return outputs, times
-
-
-def get_top_results(outputs, max_results):
- """Return the top n results from the output tensor.
-
- This function is primarily for image classification and will
- not necessarily generalise.
-
- Parameters
- ----------
- outputs : dict
- Outputs dictionary - {output_name: np.array}.
- max_results : int
- Number of results to return
-
- Returns
- -------
- top_results : np.array
- Results array of shape (2, n).
- The first row is the indices and the second is the values.
-
- """
- output = np.copy(outputs["output_0"])
- sorted_labels = output.argsort()[0][-max_results:][::-1]
- output.sort()
- sorted_values = output[0][-max_results:][::-1]
- top_results = np.array([sorted_labels, sorted_values])
- return top_results
-
-
-def format_times(times):
- """Format the mean, max, min and std of the execution times.
-
- This has the effect of producing a small table that looks like:
-
- Execution time summary:
- mean (ms) max (ms) min (ms) std (ms)
- 0.14310 0.16161 0.12933 0.01004
-
- Parameters
- ----------
- times : list
- A list of execution times (in seconds).
-
- Returns
- -------
- str
- A formatted string containing the statistics.
- """
-
- # timestamps
- mean_ts = np.mean(times) * 1000
- std_ts = np.std(times) * 1000
- max_ts = np.max(times) * 1000
- min_ts = np.min(times) * 1000
-
- header = "Execution time summary:\n{0:^10} {1:^10} {2:^10} {3:^10}".format(
- "mean (ms)", "max (ms)", "min (ms)", "std (ms)"
- )
- stats = "{0:^10.2f} {1:^10.2f} {2:^10.2f} {3:^10.2f}".format(mean_ts,
max_ts, min_ts, std_ts)
+ logger.debug("Running on remote RPC with no key.")
+ session = rpc.connect(hostname, port)
+ else:
+ # Local
+ logger.debug("Running a local session.")
+ session = rpc.LocalSession()
+
+ session.upload(tvmc_package.lib_path)
+ lib = session.load_module(tvmc_package.lib_name)
+
+ # TODO expand to other supported devices, as listed in tvm.rpc.client (@leandron)
+ logger.debug("Device is %s.", device)
+ if device == "gpu":
+ dev = session.gpu()
+ elif device == "cl":
+ dev = session.cl()
+ else:
+ assert device == "cpu"
+ dev = session.cpu()
- return "%s\n%s\n" % (header, stats)
+ if profile:
+ logger.debug("Creating runtime with profiling enabled.")
+ module = debug_executor.create(tvmc_package.graph, lib, dev, dump_root="./prof")
+ else:
+ logger.debug("Creating runtime with profiling disabled.")
+ module = runtime.create(tvmc_package.graph, lib, dev)
+
+ logger.debug("Loading params into the runtime module.")
+ module.load_params(tvmc_package.params)
+
+ shape_dict, dtype_dict = get_input_info(tvmc_package.graph, tvmc_package.params)
+ inputs_dict = make_inputs_dict(shape_dict, dtype_dict, inputs, fill_mode)
+
+ logger.debug("Setting inputs to the module.")
+ module.set_input(**inputs_dict)
+
+ # Run must be called explicitly if profiling
+ if profile:
+ logger.info("Running the module with profiling enabled.")
+ module.run()
+
+ # create the module time evaluator (returns a function)
+ timer = module.module.time_evaluator("run", dev, number=number, repeat=repeat)
+ # call the evaluator function to invoke the module and save execution times
+ prof_result = timer()
+ # collect a list of execution times from the profiling results
+ times = prof_result.results
+
+ logger.debug("Collecting the output tensors.")
+ num_outputs = module.get_num_outputs()
+ outputs = {}
+ for i in range(num_outputs):
+ output_name = "output_{}".format(i)
+ outputs[output_name] = module.get_output(i).asnumpy()
+
+ return TVMCResult(outputs, times)
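For scripting users, the new entry point can be exercised roughly as follows. This is a minimal sketch based on the API above; "my_model.tar" is an illustrative path assumed to come from an earlier tvmc.compile call:

    from tvm.driver import tvmc
    from tvm.driver.tvmc.model import TVMCPackage
    from tvm.driver.tvmc.result_utils import get_top_results

    # Load a previously compiled package back from disk.
    tvmc_package = TVMCPackage(package_path="my_model.tar")

    # Run locally on CPU; inputs are synthesized via fill_mode when not given.
    result = tvmc.run(tvmc_package, device="cpu", repeat=10, number=10)

    # TVMCResult bundles the output tensors with the measured times.
    print(result.format_times())
    top5 = get_top_results(result, 5)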
diff --git a/src/auto_scheduler/search_task.cc b/src/auto_scheduler/search_task.cc
index ffdfbcc..80fb71d 100755
--- a/src/auto_scheduler/search_task.cc
+++ b/src/auto_scheduler/search_task.cc
@@ -167,6 +167,11 @@ TVM_REGISTER_GLOBAL("auto_scheduler.HardwareParams")
max_threads_per_block, max_vthread_extent, warp_size);
});
+TVM_REGISTER_GLOBAL("auto_scheduler.GetDefaultHardwareParams")
+ .set_body_typed([](Target target, Target target_host) {
+ return HardwareParamsNode::GetDefaultHardwareParams(target, target_host);
+ });
+
TVM_REGISTER_GLOBAL("auto_scheduler.SearchTask")
.set_body_typed([](ComputeDAG compute_dag, String workload_key, Target target,
Target target_host, Optional<HardwareParams> hardware_params,
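From the Python side, this new hook is what lets HardwareParams fill in unspecified fields from a target, as in this sketch (mirroring the usage in the autoscheduler test below):

    from tvm import auto_scheduler

    # Pin num_cores explicitly; the remaining fields (vector width, cache
    # line size, etc.) are derived from the "llvm" target's defaults.
    hw_params = auto_scheduler.HardwareParams(num_cores=4, target="llvm")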
diff --git a/tests/python/driver/tvmc/conftest.py b/tests/python/driver/tvmc/conftest.py
index 3345b4f..f7cbf92 100644
--- a/tests/python/driver/tvmc/conftest.py
+++ b/tests/python/driver/tvmc/conftest.py
@@ -41,7 +41,7 @@ def download_and_untar(model_url, model_sub_path, temp_dir):
return os.path.join(temp_dir, model_sub_path)
-def get_sample_compiled_module(target_dir):
+def get_sample_compiled_module(target_dir, package_filename):
"""Support function that returns a TFLite compiled module"""
base_url = "https://storage.googleapis.com/download.tensorflow.org/models"
model_url = "mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz"
@@ -51,8 +51,10 @@ def get_sample_compiled_module(target_dir):
temp_dir=target_dir,
)
- mod, params = tvmc.frontends.load_model(model_file)
- return tvmc.compiler.compile_model(mod, params, target="llvm")
+ tvmc_model = tvmc.frontends.load_model(model_file)
+ return tvmc.compiler.compile_model(
+ tvmc_model, target="llvm", package_path=os.path.join(target_dir, package_filename)
+ )
# PyTest fixtures
@@ -101,6 +103,29 @@ def keras_resnet50(tmpdir_factory):
@pytest.fixture(scope="session")
+def keras_simple(tmpdir_factory):
+ try:
+ from tensorflow import keras
+ except ImportError:
+ # not all environments provide TensorFlow, so skip this fixture
+ # if that is the case.
+ return ""
+
+ model_file_name = "{}/{}".format(tmpdir_factory.mktemp("data"), "simple_conv.h5")
+ model = keras.Sequential(
+ [
+ keras.layers.InputLayer(input_shape=[32, 32, 3], batch_size=1),
+ keras.layers.Conv2D(8, kernel_size=(3, 3)),
+ keras.layers.Flatten(),
+ keras.layers.Dense(64),
+ ]
+ )
+ model.save(model_file_name)
+
+ return model_file_name
+
+
+@pytest.fixture(scope="session")
def pytorch_resnet18(tmpdir_factory):
try:
import torch
@@ -129,7 +154,18 @@ def onnx_resnet50():
@pytest.fixture(scope="session")
-def tflite_compiled_module_as_tarfile(tmpdir_factory):
+def onnx_mnist():
+ base_url = "https://github.com/onnx/models/raw/master/vision/classification/mnist/model"
+ file_to_download = "mnist-1.onnx"
+ model_file = download_testdata(
+ "{}/{}".format(base_url, file_to_download), file_to_download,
module=["tvmc"]
+ )
+
+ return model_file
+
+
+@pytest.fixture(scope="session")
+def tflite_compiled_model(tmpdir_factory):
# Not all CI environments will have TFLite installed
# so we need to safely skip this fixture that will
@@ -143,12 +179,7 @@ def tflite_compiled_module_as_tarfile(tmpdir_factory):
return ""
target_dir = tmpdir_factory.mktemp("data")
- graph, lib, params, _ = get_sample_compiled_module(target_dir)
-
- module_file = os.path.join(target_dir, "mock.tar")
- tvmc.compiler.save_module(module_file, graph, lib, params)
-
- return module_file
+ return get_sample_compiled_module(target_dir, "mock.tar")
@pytest.fixture(scope="session")
diff --git a/tests/python/driver/tvmc/test_autoscheduler.py b/tests/python/driver/tvmc/test_autoscheduler.py
index 25525eb..f1d750f 100644
--- a/tests/python/driver/tvmc/test_autoscheduler.py
+++ b/tests/python/driver/tvmc/test_autoscheduler.py
@@ -14,10 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import json
import pytest
import os
-import tarfile
from os import path
@@ -26,28 +24,30 @@ from tvm.driver import tvmc
def _get_tasks(model):
- mod, params = tvmc.frontends.load_model(model)
- tasks, weights = tvmc.autotuner.autoscheduler_get_tuning_tasks(mod, params, "llvm")
+ tvmc_model = tvmc.frontends.load_model(model)
+ tasks, weights = tvmc.autotuner.autoscheduler_get_tuning_tasks(
+ tvmc_model.mod, tvmc_model.params, "llvm"
+ )
return (tasks, weights)
-def _autoscheduler_test_helper(
- model, tmpdir_name, tasks_weights=None, early_stopping=1, tuning_records=None
-):
- tasks, weights = tasks_weights if tasks_weights else _get_tasks(model)
+def _autoscheduler_test_helper(model, tmpdir_name, early_stopping=1, prior_records=None):
+ tvmc_model = tvmc.frontends.load_model(model)
log_file = os.path.join(tmpdir_name, "autoscheduler.json")
- tuning_options = auto_scheduler.TuningOptions(
- num_measure_trials=1,
- measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
- runner="local",
- builder="local",
- verbose=0,
+ hardware_params = auto_scheduler.HardwareParams(num_cores=4, target="llvm")
+
+ tvmc.tune(
+ tvmc_model,
+ target="llvm",
+ tuning_records=log_file,
+ prior_records=prior_records,
early_stopping=early_stopping,
+ enable_autoscheduler=True,
+ trials=2,
+ hardware_params=hardware_params,
)
- tvmc.autotuner.schedule_tasks(tasks[:1], weights[:1], tuning_options, tuning_records)
-
# testing whether the log file was produced
assert path.exists(log_file), "autoscheduler log file should exist"
@@ -59,10 +59,10 @@ def _autoscheduler_test_helper(
return log_file
-def test_get_tuning_tasks(onnx_resnet50):
- pytest.importorskip("onnx")
+def test_get_tuning_tasks(keras_simple):
+ pytest.importorskip("tensorflow")
- tasks, weights = _get_tasks(onnx_resnet50)
+ tasks, weights = _get_tasks(keras_simple)
expected_task_type = auto_scheduler.SearchTask
assert type(tasks) is list
@@ -70,32 +70,25 @@ def test_get_tuning_tasks(onnx_resnet50):
assert all([type(x) is expected_task_type for x in tasks]) is True
-def test_tune_tasks(onnx_resnet50, tmpdir_factory):
- pytest.importorskip("onnx")
+def test_tune_tasks(keras_simple, tmpdir_factory):
+ pytest.importorskip("tensorflow")
tmpdir_name = tmpdir_factory.mktemp("data")
- _autoscheduler_test_helper(onnx_resnet50, tmpdir_name)
+ _autoscheduler_test_helper(keras_simple, tmpdir_name)
-def test_tune_tasks__tuning_records(onnx_resnet50, tmpdir_factory):
- pytest.importorskip("onnx")
+def test_tune_tasks__tuning_records(keras_simple, tmpdir_factory):
+ pytest.importorskip("tensorflow")
tmpdir_name = tmpdir_factory.mktemp("data")
- output_log_phase_1 = _autoscheduler_test_helper(onnx_resnet50, tmpdir_name)
+ output_log_phase_1 = _autoscheduler_test_helper(keras_simple, tmpdir_name)
# Exercises transfer learning by making sure a previous log exists
- _autoscheduler_test_helper(onnx_resnet50, tmpdir_name, tuning_records=output_log_phase_1)
-
-
-def test_tune_tasks__no_early_stopping(onnx_resnet50, tmpdir_factory):
- pytest.importorskip("onnx")
-
- tmpdir_name = tmpdir_factory.mktemp("data")
- _autoscheduler_test_helper(onnx_resnet50, tmpdir_name, tasks_weights=None, early_stopping=None)
+ _autoscheduler_test_helper(keras_simple, tmpdir_name, prior_records=output_log_phase_1)
-def test_tune_tasks__no_tuning_records(onnx_resnet50, tmpdir_factory):
- pytest.importorskip("onnx")
+def test_tune_tasks__no_early_stopping(keras_simple, tmpdir_factory):
+ pytest.importorskip("tensorflow")
tmpdir_name = tmpdir_factory.mktemp("data")
- _autoscheduler_test_helper(onnx_resnet50, tmpdir_name, tasks_weights=None, tuning_records=None)
+ _autoscheduler_test_helper(keras_simple, tmpdir_name, early_stopping=None)
diff --git a/tests/python/driver/tvmc/test_autotuner.py b/tests/python/driver/tvmc/test_autotuner.py
index 5ce4ca9..e82e33b 100644
--- a/tests/python/driver/tvmc/test_autotuner.py
+++ b/tests/python/driver/tvmc/test_autotuner.py
@@ -14,10 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import json
import pytest
import os
-import tarfile
from os import path
@@ -26,8 +24,8 @@ from tvm.driver import tvmc
def _get_tasks(model):
- mod, params = tvmc.frontends.load_model(model)
- return tvmc.autotuner.autotvm_get_tuning_tasks(mod, params, "llvm")
+ tvmc_model = tvmc.frontends.load_model(model)
+ return tvmc.autotuner.autotvm_get_tuning_tasks(tvmc_model.mod, tvmc_model.params, "llvm")
def _get_measure_options():
@@ -36,20 +34,18 @@ def _get_measure_options():
)
-def _tuner_test_helper(
- model, tuner_name, tmpdir_name, tasks=None, early_stopping=1, tuning_records=None
-):
- tasks = tasks if tasks else _get_tasks(model)
+def _tuner_test_helper(model, tuner_name, tmpdir_name, early_stopping=1, prior_records=None):
+ tvmc_model = tvmc.frontends.load_model(model)
log_file = os.path.join(tmpdir_name, "log_{}.txt".format(tuner_name))
- tvmc.autotuner.tune_tasks(
- tasks=[tasks[0]],
- log_file=log_file,
- measure_option=_get_measure_options(),
+ tvmc.tune(
+ tvmc_model,
+ target="llvm",
+ tuning_records=log_file,
+ prior_records=prior_records,
tuner=tuner_name,
- trials=1,
+ trials=4,
early_stopping=early_stopping,
- tuning_records=tuning_records,
)
# testing whether the log file was produced
@@ -63,10 +59,10 @@ def _tuner_test_helper(
return log_file
-def test_get_tuning_tasks(onnx_resnet50):
+def test_get_tuning_tasks(onnx_mnist):
pytest.importorskip("onnx")
- sut = _get_tasks(onnx_resnet50)
+ sut = _get_tasks(onnx_mnist)
expected_task_type = autotvm.task.Task
assert type(sut) is list
@@ -74,76 +70,85 @@ def test_get_tuning_tasks(onnx_resnet50):
assert all([type(x) is expected_task_type for x in sut]) is True
-def test_tune_tasks__tuner__xgb(onnx_resnet50, tmpdir_factory):
+def test_tune_tasks__tuner__xgb(onnx_mnist, tmpdir_factory):
pytest.importorskip("onnx")
tmpdir_name = tmpdir_factory.mktemp("data")
- _tuner_test_helper(onnx_resnet50, "xgb", tmpdir_name)
+ _tuner_test_helper(onnx_mnist, "xgb", tmpdir_name)
-def test_tune_tasks__tuner__xgb_knob(onnx_resnet50, tmpdir_factory):
+def test_tune_tasks__tuner__xgb_knob(onnx_mnist, tmpdir_factory):
pytest.importorskip("onnx")
tmpdir_name = tmpdir_factory.mktemp("data")
- _tuner_test_helper(onnx_resnet50, "xgb_knob", tmpdir_name)
+ _tuner_test_helper(onnx_mnist, "xgb_knob", tmpdir_name)
-def test_tune_tasks__tuner__ga(onnx_resnet50, tmpdir_factory):
+def test_tune_tasks__tuner__ga(onnx_mnist, tmpdir_factory):
pytest.importorskip("onnx")
tmpdir_name = tmpdir_factory.mktemp("data")
- _tuner_test_helper(onnx_resnet50, "ga", tmpdir_name)
+ _tuner_test_helper(onnx_mnist, "ga", tmpdir_name)
-def test_tune_tasks__tuner__random(onnx_resnet50, tmpdir_factory):
+def test_tune_tasks__tuner__random(onnx_mnist, tmpdir_factory):
pytest.importorskip("onnx")
tmpdir_name = tmpdir_factory.mktemp("data")
- _tuner_test_helper(onnx_resnet50, "random", tmpdir_name)
+ _tuner_test_helper(onnx_mnist, "random", tmpdir_name)
-def test_tune_tasks__tuner__gridsearch(onnx_resnet50, tmpdir_factory):
+def test_tune_tasks__tuner__gridsearch(onnx_mnist, tmpdir_factory):
pytest.importorskip("onnx")
tmpdir_name = tmpdir_factory.mktemp("data")
- _tuner_test_helper(onnx_resnet50, "gridsearch", tmpdir_name)
+ _tuner_test_helper(onnx_mnist, "gridsearch", tmpdir_name)
-def test_tune_tasks__tuner__gridsearch__tuning_records(onnx_resnet50, tmpdir_factory):
+def test_tune_tasks__tuner__gridsearch__tuning_records(onnx_mnist, tmpdir_factory):
pytest.importorskip("onnx")
tmpdir_name = tmpdir_factory.mktemp("data")
- output_log_phase_1 = _tuner_test_helper(onnx_resnet50, "gridsearch", tmpdir_name)
+ output_log_phase_1 = _tuner_test_helper(onnx_mnist, "gridsearch", tmpdir_name)
# Exercises transfer learning by making sure a previous log exists
- _tuner_test_helper(onnx_resnet50, "gridsearch", tmpdir_name,
tuning_records=output_log_phase_1)
+ _tuner_test_helper(onnx_mnist, "gridsearch", tmpdir_name,
prior_records=output_log_phase_1)
-def test_tune_tasks__tuner__ga__empty_tasks(onnx_resnet50, tmpdir_factory):
+def test_tune_tasks__tuner__ga__empty_tasks(tmpdir_factory):
pytest.importorskip("onnx")
tmpdir_name = tmpdir_factory.mktemp("data")
- _tuner_test_helper(onnx_resnet50, "ga", tmpdir_name, tasks=[])
+ log_file = os.path.join(tmpdir_name, "log_{}.txt".format("ga"))
+
+ tvmc.autotuner.tune_tasks(
+ tasks=[],
+ log_file=log_file,
+ measure_option=_get_measure_options(),
+ tuner="ga",
+ trials=1,
+ early_stopping=1,
+ )
-def test_tune_tasks__tuner__xgb__no_early_stopping(onnx_resnet50, tmpdir_factory):
+def test_tune_tasks__tuner__xgb__no_early_stopping(onnx_mnist, tmpdir_factory):
pytest.importorskip("onnx")
tmpdir_name = tmpdir_factory.mktemp("data")
- _tuner_test_helper(onnx_resnet50, "xgb", tmpdir_name, tasks=None,
early_stopping=None)
+ _tuner_test_helper(onnx_mnist, "xgb", tmpdir_name, early_stopping=None)
-def test_tune_tasks__tuner__xgb__no_tuning_records(onnx_resnet50, tmpdir_factory):
+def test_tune_tasks__tuner__xgb__no_tuning_records(onnx_mnist, tmpdir_factory):
pytest.importorskip("onnx")
tmpdir_name = tmpdir_factory.mktemp("data")
- _tuner_test_helper(onnx_resnet50, "xgb", tmpdir_name, tasks=None,
tuning_records=None)
+ _tuner_test_helper(onnx_mnist, "xgb", tmpdir_name, prior_records=None)
-def test_tune_tasks__invalid_tuner(onnx_resnet50, tmpdir_factory):
+def test_tune_tasks__invalid_tuner(onnx_mnist, tmpdir_factory):
pytest.importorskip("onnx")
- tasks = _get_tasks(onnx_resnet50)
+ tasks = _get_tasks(onnx_mnist)
log_file = os.path.join(tmpdir_factory.mktemp("data"), "log2.txt")
with pytest.raises(tvmc.common.TVMCException):
diff --git a/tests/python/driver/tvmc/test_command_line.py b/tests/python/driver/tvmc/test_command_line.py
new file mode 100644
index 0000000..66a3216
--- /dev/null
+++ b/tests/python/driver/tvmc/test_command_line.py
@@ -0,0 +1,53 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import os
+
+from tvm.driver.tvmc.main import _main
+
+
+def test_tvmc_cl_workflow(keras_simple, tmpdir_factory):
+ pytest.importorskip("tensorflow")
+
+ tmpdir = tmpdir_factory.mktemp("data")
+
+ # Test model tuning
+ log_path = os.path.join(tmpdir, "keras-autotuner_records.json")
+ tuning_str = (
+ f"tvmc tune --target llvm --output {log_path} "
+ f"--trials 2 --enable-autoscheduler {keras_simple}"
+ )
+ tuning_args = tuning_str.split(" ")[1:]
+ _main(tuning_args)
+ assert os.path.exists(log_path)
+
+ # Test model compilation
+ package_path = os.path.join(tmpdir, "keras-tvm.tar")
+ compile_str = (
+ f"tvmc compile --target llvm --tuning-records {log_path} "
+ f"--output {package_path} {keras_simple}"
+ )
+ compile_args = compile_str.split(" ")[1:]
+ _main(compile_args)
+ assert os.path.exists(package_path)
+
+ # Test running the model
+ output_path = os.path.join(tmpdir, "predictions.npz")
+ run_str = f"tvmc run --outputs {output_path} {package_path}"
+ run_args = run_str.split(" ")[1:]
+ _main(run_args)
+ assert os.path.exists(output_path)
diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py
index 24fa452..a023689 100644
--- a/tests/python/driver/tvmc/test_compiler.py
+++ b/tests/python/driver/tvmc/test_compiler.py
@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import argparse
import os
import shutil
from os import path
@@ -28,6 +27,7 @@ from tvm.relay.op.contrib.ethosn import ethosn_available
from tvm.contrib.target.vitis_ai import vitis_ai_available
from tvm.driver import tvmc
+from tvm.driver.tvmc.model import TVMCPackage
def test_save_dumps(tmpdir_factory):
@@ -45,16 +45,16 @@ def test_save_dumps(tmpdir_factory):
def verify_compile_tflite_module(model, shape_dict=None):
pytest.importorskip("tflite")
- mod, params = tvmc.load(model, shape_dict=shape_dict)
- graph, lib, params, dumps = tvmc.compile(
- mod, params, target="llvm", dump_code="ll", alter_layout="NCHW"
- )
+ tvmc_model = tvmc.load(model, shape_dict=shape_dict)
+ tvmc_package = tvmc.compile(tvmc_model, target="llvm", dump_code="ll",
desired_layout="NCHW")
+ dumps_path = tvmc_package.package_path + ".ll"
# check for output types
- assert type(graph) is str
- assert type(lib) is tvm.runtime.module.Module
- assert type(params) is dict
- assert type(dumps) is dict
+ assert type(tvmc_package) is TVMCPackage
+ assert type(tvmc_package.graph) is str
+ assert type(tvmc_package.lib_path) is str
+ assert type(tvmc_package.params) is bytearray
+ assert os.path.exists(dumps_path)
def test_compile_tflite_module(tflite_mobilenet_v1_1_quant):
@@ -75,35 +75,42 @@ def test_compile_tflite_module(tflite_mobilenet_v1_1_quant):
def test_cross_compile_aarch64_tflite_module(tflite_mobilenet_v1_1_quant):
pytest.importorskip("tflite")
- mod, params = tvmc.load(tflite_mobilenet_v1_1_quant)
- graph, lib, params, dumps = tvmc.compile(
- mod,
- params,
+ tvmc_model = tvmc.load(tflite_mobilenet_v1_1_quant)
+ tvmc_package = tvmc.compile(
+ tvmc_model,
target="llvm -device=arm_cpu -mtriple=aarch64-linux-gnu
-mattr='+neon'",
dump_code="asm",
+ cross="aarch64-linux-gnu-gcc",
)
+ dumps_path = tvmc_package.package_path + ".asm"
# check for output types
- assert type(graph) is str
- assert type(lib) is tvm.runtime.module.Module
- assert type(params) is dict
- assert type(dumps) is dict
+ assert type(tvmc_package) is TVMCPackage
+ assert type(tvmc_package.graph) is str
+ assert type(tvmc_package.lib_path) is str
+ assert type(tvmc_package.params) is bytearray
+ assert os.path.exists(dumps_path)
def test_compile_keras__save_module(keras_resnet50, tmpdir_factory):
# some CI environments won't offer tensorflow/Keras, so skip in case it is not present
pytest.importorskip("tensorflow")
- mod, params = tvmc.load(keras_resnet50)
- graph, lib, params, dumps = tvmc.compile(mod, params, target="llvm",
dump_code="ll")
-
expected_temp_dir = tmpdir_factory.mktemp("saved_output")
expected_file_name = "saved.tar"
module_file = os.path.join(expected_temp_dir, expected_file_name)
- tvmc.compiler.save_module(module_file, graph, lib, params)
+
+ tvmc_model = tvmc.load(keras_resnet50)
+ tvmc.compile(tvmc_model, target="llvm", dump_code="ll",
package_path=module_file)
assert os.path.exists(module_file), "output file {0} should exist".format(module_file)
+ # Test that we can load back in a module.
+ tvmc_package = TVMCPackage(package_path=module_file)
+ assert type(tvmc_package.lib_path) is str
+ assert type(tvmc_package.graph) is str
+ assert type(tvmc_package.params) is bytearray
+
# This test will be skipped if the AArch64 cross-compilation toolchain is not installed.
@pytest.mark.skipif(
@@ -113,34 +120,36 @@ def test_cross_compile_aarch64_keras_module(keras_resnet50):
# some CI environments won't offer tensorflow/Keras, so skip in case it is not present
pytest.importorskip("tensorflow")
- mod, params = tvmc.load(keras_resnet50)
- graph, lib, params, dumps = tvmc.compile(
- mod,
- params,
+ tvmc_model = tvmc.load(keras_resnet50)
+ tvmc_package = tvmc.compile(
+ tvmc_model,
target="llvm -device=arm_cpu -mtriple=aarch64-linux-gnu
-mattr='+neon'",
dump_code="asm",
+ cross="aarch64-linux-gnu-gcc",
)
+ dumps_path = tvmc_package.package_path + ".asm"
# check for output types
- assert type(graph) is str
- assert type(lib) is tvm.runtime.module.Module
- assert type(params) is dict
- assert type(dumps) is dict
- assert "asm" in dumps.keys()
+ assert type(tvmc_package) is TVMCPackage
+ assert type(tvmc_package.graph) is str
+ assert type(tvmc_package.lib_path) is str
+ assert type(tvmc_package.params) is bytearray
+ assert os.path.exists(dumps_path)
def verify_compile_onnx_module(model, shape_dict=None):
# some CI environments won't offer onnx, so skip in case it is not present
pytest.importorskip("onnx")
- mod, params = tvmc.load(model, shape_dict=shape_dict)
- graph, lib, params, dumps = tvmc.compile(mod, params, target="llvm",
dump_code="ll")
+ tvmc_model = tvmc.load(model, shape_dict=shape_dict)
+ tvmc_package = tvmc.compile(tvmc_model, target="llvm", dump_code="ll")
+ dumps_path = tvmc_package.package_path + ".ll"
# check for output types
- assert type(graph) is str
- assert type(lib) is tvm.runtime.module.Module
- assert type(params) is dict
- assert type(dumps) is dict
- assert "ll" in dumps.keys()
+ assert type(tvmc_package) is TVMCPackage
+ assert type(tvmc_package.graph) is str
+ assert type(tvmc_package.lib_path) is str
+ assert type(tvmc_package.params) is bytearray
+ assert os.path.exists(dumps_path)
def test_compile_onnx_module(onnx_resnet50):
@@ -160,38 +169,40 @@ def test_cross_compile_aarch64_onnx_module(onnx_resnet50):
# some CI environments won't offer onnx, so skip in case it is not present
pytest.importorskip("onnx")
- mod, params = tvmc.load(onnx_resnet50)
- graph, lib, params, dumps = tvmc.compile(
- mod,
- params,
+ tvmc_model = tvmc.load(onnx_resnet50)
+ tvmc_package = tvmc.compile(
+ tvmc_model,
target="llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon",
dump_code="asm",
+ cross="aarch64-linux-gnu-gcc",
)
+ dumps_path = tvmc_package.package_path + ".asm"
# check for output types
- assert type(graph) is str
- assert type(lib) is tvm.runtime.module.Module
- assert type(params) is dict
- assert type(dumps) is dict
- assert "asm" in dumps.keys()
+ assert type(tvmc_package) is TVMCPackage
+ assert type(tvmc_package.graph) is str
+ assert type(tvmc_package.lib_path) is str
+ assert type(tvmc_package.params) is bytearray
+ assert os.path.exists(dumps_path)
@tvm.testing.requires_opencl
def test_compile_opencl(tflite_mobilenet_v1_0_25_128):
pytest.importorskip("tflite")
- mod, params = tvmc.load(tflite_mobilenet_v1_0_25_128)
- graph, lib, params, dumps = tvmc.compile(
- mod,
- params,
+ tvmc_model = tvmc.load(tflite_mobilenet_v1_0_25_128)
+ tvmc_package = tvmc.compile(
+ tvmc_model,
target="opencl --host=llvm",
- alter_layout="NCHW",
+ desired_layout="NCHW",
)
+ dumps_path = tvmc_package.package_path + ".asm"
# check for output types
- assert type(graph) is str
- assert type(lib) is tvm.runtime.module.Module
- assert type(params) is dict
- assert type(dumps) is dict
+ assert type(tvmc_package) is TVMCPackage
+ assert type(tvmc_package.graph) is str
+ assert type(tvmc_package.lib_path) is str
+ assert type(tvmc_package.params) is bytearray
+ assert os.path.exists(dumps_path)
@pytest.mark.skipif(
@@ -200,16 +211,16 @@ def test_compile_opencl(tflite_mobilenet_v1_0_25_128):
)
def test_compile_tflite_module_with_external_codegen(tflite_mobilenet_v1_1_quant):
pytest.importorskip("tflite")
- mod, params = tvmc.load(tflite_mobilenet_v1_1_quant)
- graph, lib, params, dumps = tvmc.compile(
- mod, params, target="ethos-n77, llvm", dump_code="relay"
- )
+ tvmc_model = tvmc.load(tflite_mobilenet_v1_1_quant)
+ tvmc_package = tvmc.compile(tvmc_model, target="ethos-n77, llvm",
dump_code="relay")
+ dumps_path = tvmc_package.package_path + ".relay"
# check for output types
- assert type(graph) is str
- assert type(lib) is tvm.runtime.module.Module
- assert type(params) is dict
- assert type(dumps) is dict
+ assert type(tvmc_package) is TVMCPackage
+ assert type(tvmc_package.graph) is str
+ assert type(tvmc_package.lib_path) is str
+ assert type(tvmc_package.params) is bytearray
+ assert os.path.exists(dumps_path)
@pytest.mark.skipif(
@@ -219,36 +230,38 @@ def test_compile_tflite_module_with_external_codegen(tflite_mobilenet_v1_1_quant
def test_compile_tflite_module_with_external_codegen_vitis_ai(tflite_mobilenet_v1_1_quant):
pytest.importorskip("tflite")
- mod, params = tvmc.load(tflite_mobilenet_v1_1_quant)
- graph, lib, params, dumps = tvmc.compiler.compile_model(
- mod,
- params,
+ tvmc_model = tvmc.load(tflite_mobilenet_v1_1_quant)
+ tvmc_package = tvmc.compiler.compile_model(
+ tvmc_model,
target="vitis-ai -dpu=DPUCZDX8G-zcu104
-export_runtime_module=vitis_ai.rtmod, llvm",
dump_code="relay",
)
+ dumps_path = tvmc_package.package_path + ".relay"
# check for output types
- assert type(graph) is str
- assert type(lib) is tvm.runtime.module.Module
- assert type(params) is dict
- assert type(dumps) is dict
+ assert type(tvmc_package) is TVMCPackage
+ assert type(tvmc_package.graph) is str
+ assert type(tvmc_package.lib_path) is str
+ assert type(tvmc_package.params) is bytearray
+ assert os.path.exists(dumps_path)
@mock.patch("tvm.relay.build")
@mock.patch("tvm.driver.tvmc.composite_target.get_codegen_by_target")
@mock.patch("tvm.driver.tvmc.load")
@mock.patch("tvm.transform.PassContext")
-def test_compile_check_configs_composite_target(mock_pc, mock_fe, mock_ct, mock_relay):
+@mock.patch("tvm.driver.tvmc.model.TVMCPackage.__init__", return_value=None)
+def test_compile_check_configs_composite_target(mock_pkg, mock_pc, mock_fe, mock_ct, mock_relay):
mock_codegen = {}
mock_codegen["config_key"] = "relay.ext.mock.options"
mock_codegen["pass_pipeline"] = lambda *args, **kwargs: None
- mock_fe.return_value = (None, None)
+ mock_fe.return_value = mock.MagicMock()
mock_ct.return_value = mock_codegen
mock_relay.return_value = mock.MagicMock()
- mod, params = tvmc.load("no_file_needed")
- graph, lib, params, dumps = tvmc.compile(mod, params, target="mockcodegen
-testopt=value, llvm")
+ tvmc_model = tvmc.load("no_file_needed")
+ tvmc.compile(tvmc_model, target="mockcodegen -testopt=value, llvm")
mock_pc.assert_called_once_with(
opt_level=3,
diff --git a/tests/python/driver/tvmc/test_frontends.py b/tests/python/driver/tvmc/test_frontends.py
index 3da63d4..adf62eb 100644
--- a/tests/python/driver/tvmc/test_frontends.py
+++ b/tests/python/driver/tvmc/test_frontends.py
@@ -23,6 +23,7 @@ from tvm.ir.module import IRModule
from tvm.driver import tvmc
from tvm.driver.tvmc.common import TVMCException
+from tvm.driver.tvmc.model import TVMCModel
def test_get_frontends_contains_only_strings():
@@ -108,11 +109,12 @@ def test_load_model__tflite(tflite_mobilenet_v1_1_quant):
# some CI environments won't offer TFLite, so skip in case it is not present
pytest.importorskip("tflite")
- mod, params = tvmc.load(tflite_mobilenet_v1_1_quant)
- assert type(mod) is IRModule
- assert type(params) is dict
+ tvmc_model = tvmc.load(tflite_mobilenet_v1_1_quant)
+ assert type(tvmc_model) is TVMCModel
+ assert type(tvmc_model.mod) is IRModule
+ assert type(tvmc_model.params) is dict
# check whether one known value is part of the params dict
- assert "_param_1" in params.keys()
+ assert "_param_1" in tvmc_model.params.keys()
@pytest.mark.parametrize("load_model_kwargs", [{}, {"layout": "NCHW"}])
@@ -120,40 +122,43 @@ def test_load_model__keras(keras_resnet50, load_model_kwargs):
# some CI environments won't offer TensorFlow/Keras, so skip in case it is not present
pytest.importorskip("tensorflow")
- mod, params = tvmc.frontends.load_model(keras_resnet50, **load_model_kwargs)
- assert type(mod) is IRModule
- assert type(params) is dict
+ tvmc_model = tvmc.frontends.load_model(keras_resnet50, **load_model_kwargs)
+ assert type(tvmc_model) is TVMCModel
+ assert type(tvmc_model.mod) is IRModule
+ assert type(tvmc_model.params) is dict
## check whether one known value is part of the params dict
- assert "_param_1" in params.keys()
+ assert "_param_1" in tvmc_model.params.keys()
def verify_load_model__onnx(model, **kwargs):
- mod, params = tvmc.frontends.load_model(model, **kwargs)
- assert type(mod) is IRModule
- assert type(params) is dict
- return mod, params
+ tvmc_model = tvmc.frontends.load_model(model, **kwargs)
+ assert type(tvmc_model) is TVMCModel
+ assert type(tvmc_model.mod) is IRModule
+ assert type(tvmc_model.params) is dict
+ return tvmc_model
def test_load_model__onnx(onnx_resnet50):
# some CI environments won't offer onnx, so skip in case it is not present
pytest.importorskip("onnx")
- mod, params = verify_load_model__onnx(onnx_resnet50)
+ tvmc_model = verify_load_model__onnx(onnx_resnet50)
# check whether one known value is part of the params dict
- assert "resnetv24_batchnorm0_gamma" in params.keys()
- mod, params = verify_load_model__onnx(onnx_resnet50, freeze_params=True)
+ assert "resnetv24_batchnorm0_gamma" in tvmc_model.params.keys()
+ tvmc_model = verify_load_model__onnx(onnx_resnet50, freeze_params=True)
# check that the parameter dict is empty, implying that they have been folded into constants
- assert params == {}
+ assert tvmc_model.params == {}
def test_load_model__pb(pb_mobilenet_v1_1_quant):
# some CI environments won't offer TensorFlow, so skip in case it is not present
pytest.importorskip("tensorflow")
- mod, params = tvmc.load(pb_mobilenet_v1_1_quant)
- assert type(mod) is IRModule
- assert type(params) is dict
+ tvmc_model = tvmc.load(pb_mobilenet_v1_1_quant)
+ assert type(tvmc_model) is TVMCModel
+ assert type(tvmc_model.mod) is IRModule
+ assert type(tvmc_model.params) is dict
# check whether one known value is part of the params dict
- assert "MobilenetV1/Conv2d_0/weights" in params.keys()
+ assert "MobilenetV1/Conv2d_0/weights" in tvmc_model.params.keys()
def test_load_model___wrong_language__to_keras(tflite_mobilenet_v1_1_quant):
@@ -188,11 +193,12 @@ def test_load_model__pth(pytorch_resnet18):
pytest.importorskip("torch")
pytest.importorskip("torchvision")
- mod, params = tvmc.load(pytorch_resnet18, shape_dict={"input": [1, 3, 224, 224]})
- assert type(mod) is IRModule
- assert type(params) is dict
+ tvmc_model = tvmc.load(pytorch_resnet18, shape_dict={"input": [1, 3, 224, 224]})
+ assert type(tvmc_model) is TVMCModel
+ assert type(tvmc_model.mod) is IRModule
+ assert type(tvmc_model.params) is dict
# check whether one known value is part of the params dict
- assert "layer1.0.conv1.weight" in params.keys()
+ assert "layer1.0.conv1.weight" in tvmc_model.params.keys()
def test_load_model___wrong_language__to_pytorch(tflite_mobilenet_v1_1_quant):
diff --git a/tests/python/driver/tvmc/test_model.py b/tests/python/driver/tvmc/test_model.py
new file mode 100644
index 0000000..f5a28d4
--- /dev/null
+++ b/tests/python/driver/tvmc/test_model.py
@@ -0,0 +1,65 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import os
+
+from os import path
+
+from tvm.driver import tvmc
+from tvm.driver.tvmc.model import TVMCModel, TVMCPackage, TVMCResult
+
+
+def test_tvmc_workflow(keras_simple):
+ pytest.importorskip("tensorflow")
+
+ tvmc_model = tvmc.load(keras_simple)
+ tuning_records = tvmc.tune(tvmc_model, target="llvm", enable_autoscheduler=True, trials=2)
+ tvmc_package = tvmc.compile(tvmc_model, tuning_records=tuning_records, target="llvm")
+ result = tvmc.run(tvmc_package, device="cpu")
+ assert type(tvmc_model) is TVMCModel
+ assert type(tvmc_package) is TVMCPackage
+ assert type(result) is TVMCResult
+ assert path.exists(tuning_records)
+ assert type(result.outputs) is dict
+ assert type(result.times) is tuple
+ assert "output_0" in result.outputs.keys()
+
+
+def test_save_load_model(keras_simple, tmpdir_factory):
+ pytest.importorskip("onnx")
+
+ tmpdir = tmpdir_factory.mktemp("data")
+ tvmc_model = tvmc.load(keras_simple)
+
+ # Create tuning artifacts
+ tvmc.tune(tvmc_model, target="llvm", trials=2)
+
+ # Create package artifacts
+ tvmc.compile(tvmc_model, target="llvm")
+
+ # Save the model to disk
+ model_path = os.path.join(tmpdir, "saved_model.tar")
+ tvmc_model.save(model_path)
+
+ # Load the model into a new TVMCModel
+ new_tvmc_model = TVMCModel(model_path=model_path)
+
+ # Check that the two models match.
+ assert str(new_tvmc_model.mod) == str(tvmc_model.mod)
+ # Check that tuning records and the compiled package are recoverable.
+ assert path.exists(new_tvmc_model.default_package_path())
+ assert path.exists(new_tvmc_model.default_tuning_records_path())
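Outside of the test suite, the same save/load round trip looks like this sketch ("model.h5" and "saved_model.tar" are illustrative paths, not names from this change):

    from os import path

    from tvm.driver import tvmc
    from tvm.driver.tvmc.model import TVMCModel

    tvmc_model = tvmc.load("model.h5")
    tvmc.tune(tvmc_model, target="llvm", trials=2)  # writes tuning records
    tvmc.compile(tvmc_model, target="llvm")         # writes a compiled package

    # Persist everything, then recover it in a fresh TVMCModel.
    tvmc_model.save("saved_model.tar")
    restored = TVMCModel(model_path="saved_model.tar")
    assert path.exists(restored.default_package_path())
    assert path.exists(restored.default_tuning_records_path())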
diff --git a/tests/python/driver/tvmc/test_runner.py b/tests/python/driver/tvmc/test_runner.py
index cbea7e3..5277a79 100644
--- a/tests/python/driver/tvmc/test_runner.py
+++ b/tests/python/driver/tvmc/test_runner.py
@@ -18,6 +18,8 @@ import pytest
import numpy as np
from tvm.driver import tvmc
+from tvm.driver.tvmc.model import TVMCResult
+from tvm.driver.tvmc.result_utils import get_top_results
def test_generate_tensor_data_zeros():
@@ -50,14 +52,16 @@ def test_generate_tensor_data__type_unknown():
def test_format_times__contains_header():
- sut = tvmc.runner.format_times([0.6, 1.2, 0.12, 0.42])
+ fake_result = TVMCResult(outputs=None, times=[0.6, 1.2, 0.12, 0.42])
+ sut = fake_result.format_times()
assert "std (ms)" in sut
def test_get_top_results_keep_results():
fake_outputs = {"output_0": np.array([[1, 2, 3, 4], [5, 6, 7, 8]])}
+ fake_result = TVMCResult(outputs=fake_outputs, times=None)
number_of_results_wanted = 3
- sut = tvmc.runner.get_top_results(fake_outputs, number_of_results_wanted)
+ sut = get_top_results(fake_result, number_of_results_wanted)
expected_number_of_lines = 2
assert len(sut) == expected_number_of_lines
@@ -67,16 +71,14 @@ def test_get_top_results_keep_results():
assert len(sut[1]) == expected_number_of_results_per_line
-def test_run_tflite_module__with_profile__valid_input(
- tflite_compiled_module_as_tarfile, imagenet_cat
-):
+def test_run_tflite_module__with_profile__valid_input(tflite_compiled_model, imagenet_cat):
# some CI environments won't offer TFLite, so skip in case it is not present
pytest.importorskip("tflite")
inputs = np.load(imagenet_cat)
- outputs, times = tvmc.run(
- tflite_compiled_module_as_tarfile,
+ result = tvmc.run(
+ tflite_compiled_model,
inputs=inputs,
hostname=None,
device="cpu",
@@ -84,7 +86,7 @@ def test_run_tflite_module__with_profile__valid_input(
)
# collect the top 5 results
- top_5_results = tvmc.runner.get_top_results(outputs, 5)
+ top_5_results = get_top_results(result, 5)
top_5_ids = top_5_results[0]
# IDs were collected from this reference:
@@ -95,6 +97,6 @@ def test_run_tflite_module__with_profile__valid_input(
assert (
tiger_cat_mobilenet_id in top_5_ids
), "tiger cat is expected in the top-5 for mobilenet v1"
- assert type(outputs) is dict
- assert type(times) is tuple
- assert "output_0" in outputs.keys()
+ assert type(result.outputs) is dict
+ assert type(result.times) is tuple
+ assert "output_0" in result.outputs.keys()
diff --git a/tests/python/driver/tvmc/test_tvmc_common.py b/tests/python/driver/tvmc/test_tvmc_common.py
index 474649d..078076b 100644
--- a/tests/python/driver/tvmc/test_tvmc_common.py
+++ b/tests/python/driver/tvmc/test_tvmc_common.py
@@ -15,13 +15,10 @@
# specific language governing permissions and limitations
# under the License.
import argparse
-import os
-from os import path
import pytest
import tvm
-from tvm import relay
from tvm.driver import tvmc
from tvm.driver.tvmc.common import TVMCException
@@ -31,7 +28,8 @@ def test_compile_tflite_module_nhwc_to_nchw(tflite_mobilenet_v1_1_quant):
# some CI environments won't offer TFLite, so skip in case it is not present
pytest.importorskip("tflite")
- before, _ = tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant)
+ tvmc_model = tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant)
+ before = tvmc_model.mod
expected_layout = "NCHW"
after = tvmc.common.convert_graph_layout(before, expected_layout)
@@ -55,7 +53,8 @@ def test_compile_onnx_module_nchw_to_nhwc(onnx_resnet50):
# some CI environments won't offer ONNX, so skip in case it is not present
pytest.importorskip("onnx")
- before, _ = tvmc.frontends.load_model(onnx_resnet50)
+ tvmc_model = tvmc.frontends.load_model(onnx_resnet50)
+ before = tvmc_model.mod
expected_layout = "NHWC"
after = tvmc.common.convert_graph_layout(before, expected_layout)
@@ -79,7 +78,8 @@ def test_compile_tflite_module__same_layout__nhwc_to_nhwc(tflite_mobilenet_v1_1_
# some CI environments won't offer TFLite, so skip in case it is not present
pytest.importorskip("tflite")
- before, _ = tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant)
+ tvmc_model = tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant)
+ before = tvmc_model.mod
expected_layout = "NHWC"
after = tvmc.common.convert_graph_layout(before, expected_layout)
@@ -103,7 +103,8 @@ def test_compile_onnx_module__same_layout__nchw_to_nchw(onnx_resnet50):
# some CI environments won't offer ONNX, so skip in case it is not present
pytest.importorskip("onnx")
- before, _ = tvmc.frontends.load_model(onnx_resnet50)
+ tvmc_model = tvmc.frontends.load_model(onnx_resnet50)
+ before = tvmc_model.mod
expected_layout = "NCHW"
after = tvmc.common.convert_graph_layout(before, expected_layout)