This is an automated email from the ASF dual-hosted git repository.
leandron pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4b0f18e [TVMC] Split common tvmc file into more specific files (#9529)
4b0f18e is described below
commit 4b0f18e433f1517e1da36bb0d6cad158fa42a9a8
Author: Christopher Sidebottom <[email protected]>
AuthorDate: Tue Jan 4 16:18:13 2022 +0000
[TVMC] Split common tvmc file into more specific files (#9529)
This follows from #9206 and splits common.py into multiple smaller and more
focussed files.
---
python/tvm/driver/tvmc/__init__.py | 11 +-
python/tvm/driver/tvmc/arguments.py | 52 ++
python/tvm/driver/tvmc/autotuner.py | 21 +-
python/tvm/driver/tvmc/common.py | 799 ----------------------
python/tvm/driver/tvmc/compiler.py | 19 +-
python/tvm/driver/tvmc/composite_target.py | 2 +-
python/tvm/driver/tvmc/frontends.py | 3 +-
python/tvm/driver/tvmc/main.py | 3 +-
python/tvm/driver/tvmc/micro.py | 6 +-
python/tvm/driver/tvmc/model.py | 3 +-
python/tvm/driver/tvmc/pass_config.py | 122 ++++
python/tvm/driver/tvmc/pass_list.py | 54 ++
python/tvm/driver/tvmc/project.py | 233 +++++++
python/tvm/driver/tvmc/registry.py | 2 +-
python/tvm/driver/tvmc/runner.py | 11 +-
python/tvm/driver/tvmc/shape_parser.py | 67 ++
python/tvm/driver/tvmc/target.py | 278 ++++++++
python/tvm/driver/tvmc/tracker.py | 57 ++
python/tvm/driver/tvmc/transform.py | 62 ++
tests/python/driver/tvmc/test_autotuner.py | 2 +-
tests/python/driver/tvmc/test_compiler.py | 6 +-
tests/python/driver/tvmc/test_composite_target.py | 2 +-
tests/python/driver/tvmc/test_frontends.py | 13 +-
tests/python/driver/tvmc/test_pass_config.py | 16 +-
tests/python/driver/tvmc/test_pass_list.py | 8 +-
tests/python/driver/tvmc/test_registry_options.py | 2 +-
tests/python/driver/tvmc/test_runner.py | 2 +-
tests/python/driver/tvmc/test_shape_parser.py | 22 +-
tests/python/driver/tvmc/test_target.py | 43 +-
tests/python/driver/tvmc/test_target_options.py | 11 +-
tests/python/driver/tvmc/test_tracker.py | 8 +-
31 files changed, 1036 insertions(+), 904 deletions(-)
diff --git a/python/tvm/driver/tvmc/__init__.py
b/python/tvm/driver/tvmc/__init__.py
index 70747cb..24bb2bc 100644
--- a/python/tvm/driver/tvmc/__init__.py
+++ b/python/tvm/driver/tvmc/__init__.py
@@ -14,11 +14,20 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=redefined-builtin
+# pylint: disable=redefined-builtin,wrong-import-position
"""
TVMC - TVM driver command-line interface
"""
+
+class TVMCException(Exception):
+ """TVMC Exception"""
+
+
+class TVMCImportError(TVMCException):
+ """TVMC TVMCImportError"""
+
+
from . import micro
from . import runner
from . import autotuner
diff --git a/python/tvm/driver/tvmc/arguments.py
b/python/tvm/driver/tvmc/arguments.py
new file mode 100644
index 0000000..57b6ee2
--- /dev/null
+++ b/python/tvm/driver/tvmc/arguments.py
@@ -0,0 +1,52 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+TVMC Argument Parsing
+"""
+
+import argparse
+
+from tvm.driver.tvmc import TVMCException
+
+
+class TVMCSuppressedArgumentParser(argparse.ArgumentParser):
+ """
+ A silent ArgumentParser class.
+ This class is meant to be used as a helper for creating dynamic parsers in
+ TVMC. It will create a "suppressed" parser based on an existing one (parent)
+ which does not include a help message, does not print a usage message (even
+ when -h or --help is passed) and does not exit on invalid choice parse
+ errors but rather throws a TVMCException so it can be handled and the
+ dynamic parser construction is not interrupted prematurely.
+ """
+
+ def __init__(self, parent, **kwargs):
+ # Don't add '-h' or '--help' options to the newly created parser.
Don't print usage message.
+ # 'add_help=False' won't suppress existing '-h' and '--help' options
from the parser (and its
+ # subparsers) present in 'parent'. However that class is meant to be
used with the main
+ # parser, which is created with `add_help=False` - the help is added
only later. Hence
+ # the newly created parser won't have help options added in its (main)
root parser. The
+ # subparsers in the main parser will eventually have help activated,
which is enough for its
+ # use in TVMC.
+ super().__init__(parents=[parent], add_help=False,
usage=argparse.SUPPRESS, **kwargs)
+
+ def exit(self, status=0, message=None):
+ # Don't exit on error when parsing the command line.
+ # This won't catch all the errors generated when parsing, though. For
instance, it won't catch
+ # errors due to missing required arguments. But this will catch
"error: invalid choice",
+ # which is what is necessary for its use in TVMC.
+ raise TVMCException()
diff --git a/python/tvm/driver/tvmc/autotuner.py
b/python/tvm/driver/tvmc/autotuner.py
index 60bec38..8f14c80 100644
--- a/python/tvm/driver/tvmc/autotuner.py
+++ b/python/tvm/driver/tvmc/autotuner.py
@@ -34,11 +34,12 @@ from tvm.autotvm.tuner import RandomTuner
from tvm.autotvm.tuner import XGBTuner
from tvm.target import Target
-from . import common, composite_target, frontends
-from .common import TVMCException
+from . import TVMCException, composite_target, frontends
from .main import register_parser
from .model import TVMCModel
-from .target import generate_target_args, reconstruct_target_args
+from .target import target_from_cli, generate_target_args,
reconstruct_target_args
+from .shape_parser import parse_shape_string
+from .transform import convert_graph_layout
# pylint: disable=invalid-name
@@ -220,7 +221,7 @@ def add_tune_parser(subparsers, _):
"--input-shapes",
help="specify non-generic shapes for model to run, format is "
'"input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]"',
- type=common.parse_shape_string,
+ type=parse_shape_string,
)
@@ -256,9 +257,7 @@ def drive_tune(args):
logger.info("RPC tracker port: %s", rpc_port)
if not args.rpc_key:
- raise common.TVMCException(
- "need to provide an RPC tracker key (--rpc-key) for remote
tuning"
- )
+ raise TVMCException("need to provide an RPC tracker key
(--rpc-key) for remote tuning")
else:
rpc_hostname = None
rpc_port = None
@@ -376,7 +375,7 @@ def tune_model(
tuning_records : str
The path to the produced tuning log file.
"""
- target, extra_targets = common.target_from_cli(target,
additional_target_options)
+ target, extra_targets = target_from_cli(target, additional_target_options)
target, target_host = Target.check_and_update_host_consist(target,
target_host)
# TODO(jwfromm) Remove this deepcopy once AlterOpLayout bug that mutates
source
# model is fixed. For now, creating a clone avoids the issue.
@@ -399,7 +398,7 @@ def tune_model(
if rpc_key:
if hostname is None or port is None:
- raise common.TVMCException(
+ raise TVMCException(
"You must provide a hostname and port to connect to a remote
RPC device."
)
if isinstance(port, str):
@@ -520,7 +519,7 @@ def autotvm_get_tuning_tasks(
target, target_host = Target.check_and_update_host_consist(target,
target_host)
if alter_layout:
- mod = common.convert_graph_layout(mod, alter_layout)
+ mod = convert_graph_layout(mod, alter_layout)
tasks = autotvm.task.extract_from_program(
mod["main"],
@@ -569,7 +568,7 @@ def autoscheduler_get_tuning_tasks(
target, target_host = Target.check_and_update_host_consist(target,
target_host)
if alter_layout:
- mod = common.convert_graph_layout(mod, alter_layout)
+ mod = convert_graph_layout(mod, alter_layout)
# Extract the tasks
tasks, task_weights = auto_scheduler.extract_tasks(
diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py
deleted file mode 100644
index 498da23..0000000
--- a/python/tvm/driver/tvmc/common.py
+++ /dev/null
@@ -1,799 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-Common utility functions shared by TVMC modules.
-"""
-import re
-import json
-import logging
-import os.path
-import argparse
-import pathlib
-from typing import Union
-from collections import defaultdict
-from urllib.parse import urlparse
-
-import tvm
-from tvm.driver import tvmc
-from tvm import relay
-from tvm import transform
-from tvm._ffi import registry
-from .fmtopt import format_option
-
-# pylint: disable=invalid-name
-logger = logging.getLogger("TVMC")
-
-
-class TVMCException(Exception):
- """TVMC Exception"""
-
-
-class TVMCSuppressedArgumentParser(argparse.ArgumentParser):
- """
- A silent ArgumentParser class.
-
- This class is meant to be used as a helper for creating dynamic parsers in
- TVMC. It will create a "supressed" parser based on an existing one (parent)
- which does not include a help message, does not print a usage message (even
- when -h or --help is passed) and does not exit on invalid choice parse
- errors but rather throws a TVMCException so it can be handled and the
- dynamic parser construction is not interrupted prematurely.
-
- """
-
- def __init__(self, parent, **kwargs):
- # Don't add '-h' or '--help' options to the newly created parser.
Don't print usage message.
- # 'add_help=False' won't supress existing '-h' and '--help' options
from the parser (and its
- # subparsers) present in 'parent'. However that class is meant to be
used with the main
- # parser, which is created with `add_help=False` - the help is added
only later. Hence it
- # the newly created parser won't have help options added in its (main)
root parser. The
- # subparsers in the main parser will eventually have help activated,
which is enough for its
- # use in TVMC.
- super().__init__(parents=[parent], add_help=False,
usage=argparse.SUPPRESS, **kwargs)
-
- def exit(self, status=0, message=None):
- # Don't exit on error when parsing the command line.
- # This won't catch all the errors generated when parsing tho. For
instance, it won't catch
- # errors due to missing required arguments. But this will catch
"error: invalid choice",
- # which is what it's necessary for its use in TVMC.
- raise TVMCException()
-
-
-class TVMCImportError(TVMCException):
- """TVMC TVMCImportError"""
-
-
-def convert_graph_layout(mod, desired_layout):
- """Alter the layout of the input graph.
-
- Parameters
- ----------
- mod : tvm.IRModule
- The relay module to convert.
- desired_layout : str
- The layout to convert to.
-
- Returns
- -------
- mod : tvm.IRModule
- The converted module.
- """
-
- # Assume for the time being that graphs only have
- # conv2d as heavily-sensitive operators.
- desired_layouts = {
- "nn.conv2d": [desired_layout, "default"],
- "nn.conv2d_transpose": [desired_layout, "default"],
- "qnn.conv2d": [desired_layout, "default"],
- }
-
- # Convert the layout of the graph where possible.
- seq = transform.Sequential(
- [
- relay.transform.RemoveUnusedFunctions(),
- relay.transform.ConvertLayout(desired_layouts),
- ]
- )
-
- with transform.PassContext(opt_level=3):
- try:
- return seq(mod)
- except Exception as err:
- raise TVMCException(
- "Error converting layout to {0}: {1}".format(desired_layout,
str(err))
- )
-
-
-def validate_targets(parse_targets, additional_target_options=None):
- """
- Apply a series of validations in the targets provided via CLI.
- """
- tvm_target_kinds = tvm.target.Target.list_kinds()
- targets = [t["name"] for t in parse_targets]
-
- if len(targets) > len(set(targets)):
- raise TVMCException("Duplicate target definitions are not allowed")
-
- if targets[-1] not in tvm_target_kinds:
- tvm_target_names = ", ".join(tvm_target_kinds)
- raise TVMCException(
- f"The last target needs to be a TVM target. Choices:
{tvm_target_names}"
- )
-
- tvm_targets = [t for t in targets if t in tvm_target_kinds]
- if len(tvm_targets) > 2:
- verbose_tvm_targets = ", ".join(tvm_targets)
- raise TVMCException(
- "Only two of the following targets can be used at a time. "
- f"Found: {verbose_tvm_targets}."
- )
-
- if additional_target_options is not None:
- for target_name in additional_target_options:
- if not any([target for target in parse_targets if target["name"]
== target_name]):
- first_option =
list(additional_target_options[target_name].keys())[0]
- raise TVMCException(
- f"Passed --target-{target_name}-{first_option}"
- f" but did not specify {target_name} target"
- )
-
-
-def tokenize_target(target):
- """
- Extract a list of tokens from a target specification text.
-
- It covers some corner-cases that are not covered by the built-in
- module 'shlex', such as the use of "+" as a punctuation character.
-
-
- Example
- -------
-
- For the input `foo -op1=v1 -op2="v ,2", bar -op3=v-4` we
- should obtain:
-
- ["foo", "-op1=v1", "-op2="v ,2"", ",", "bar", "-op3=v-4"]
-
- Parameters
- ----------
- target : str
- Target options sent via CLI arguments
-
- Returns
- -------
- list of str
- a list of parsed tokens extracted from the target string
- """
-
- # Regex to tokenize the "--target" value. It is split into five parts
- # to match with:
- # 1. target and option names e.g. llvm, -mattr=, -mcpu=
- # 2. option values, all together, without quotes e.g. -mattr=+foo,+opt
- # 3. option values, when single quotes are used e.g. -mattr='+foo, +opt'
- # 4. option values, when double quotes are used e.g. -mattr="+foo ,+opt"
- # 5. commas that separate different targets e.g. "my-target, llvm"
- target_pattern = (
- r"(\-{0,2}[\w\-]+\=?"
- r"(?:[\w\+\-\.]+(?:,[\w\+\-\.])*"
- r"|[\'][\w\+\-,\s\.]+[\']"
- r"|[\"][\w\+\-,\s\.]+[\"])*"
- r"|,)"
- )
-
- return re.findall(target_pattern, target)
-
-
-def parse_target(target):
- """
- Parse a plain string of targets provided via a command-line
- argument.
-
- To send more than one codegen, a comma-separated list
- is expected. Options start with -<option_name>=<value>.
-
- We use python standard library 'shlex' to parse the argument in
- a POSIX compatible way, so that if options are defined as
- strings with spaces or commas, for example, this is considered
- and parsed accordingly.
-
-
- Example
- -------
-
- For the input `--target="foo -op1=v1 -op2="v ,2", bar -op3=v-4"` we
- should obtain:
-
- [
- {
- name: "foo",
- opts: {"op1":"v1", "op2":"v ,2"},
- raw: 'foo -op1=v1 -op2="v ,2"'
- },
- {
- name: "bar",
- opts: {"op3":"v-4"},
- raw: 'bar -op3=v-4'
- }
- ]
-
- Parameters
- ----------
- target : str
- Target options sent via CLI arguments
-
- Returns
- -------
- codegens : list of dict
- This list preserves the order in which codegens were
- provided via command line. Each Dict contains three keys:
- 'name', containing the name of the codegen; 'opts' containing
- a key-value for all options passed via CLI; 'raw',
- containing the plain string for this codegen
- """
- codegen_names = tvmc.composite_target.get_codegen_names()
- codegens = []
-
- tvm_target_kinds = tvm.target.Target.list_kinds()
- parsed_tokens = tokenize_target(target)
-
- split_codegens = []
- current_codegen = []
- split_codegens.append(current_codegen)
- for token in parsed_tokens:
- # every time there is a comma separating
- # two codegen definitions, prepare for
- # a new codegen
- if token == ",":
- current_codegen = []
- split_codegens.append(current_codegen)
- else:
- # collect a new token for the current
- # codegen being parsed
- current_codegen.append(token)
-
- # at this point we have a list of lists,
- # each item on the first list is a codegen definition
- # in the comma-separated values
- for codegen_def in split_codegens:
- # the first is expected to be the name
- name = codegen_def[0]
- is_tvm_target = name in tvm_target_kinds and name not in codegen_names
- raw_target = " ".join(codegen_def)
- all_opts = codegen_def[1:] if len(codegen_def) > 1 else []
- opts = {}
- for opt in all_opts:
- try:
- # deal with -- prefixed flags
- if opt.startswith("--"):
- opt_name = opt[2:]
- opt_value = True
- else:
- opt = opt[1:] if opt.startswith("-") else opt
- opt_name, opt_value = opt.split("=", maxsplit=1)
-
- # remove quotes from the value: quotes are only parsed if
they match,
- # so it is safe to assume that if the string starts with
quote, it ends
- # with quote.
- opt_value = opt_value[1:-1] if opt_value[0] in ('"', "'")
else opt_value
- except ValueError:
- raise ValueError(f"Error when parsing '{opt}'")
-
- opts[opt_name] = opt_value
-
- codegens.append(
- {"name": name, "opts": opts, "raw": raw_target, "is_tvm_target":
is_tvm_target}
- )
-
- return codegens
-
-
-def is_inline_json(target):
- try:
- json.loads(target)
- return True
- except json.decoder.JSONDecodeError:
- return False
-
-
-def _combine_target_options(target, additional_target_options=None):
- if additional_target_options is None:
- return target
- if target["name"] in additional_target_options:
- target["opts"].update(additional_target_options[target["name"]])
- return target
-
-
-def _recombobulate_target(target):
- name = target["name"]
- opts = " ".join([f"-{key}={value}" for key, value in
target["opts"].items()])
- return f"{name} {opts}"
-
-
-def target_from_cli(target, additional_target_options=None):
- """
- Create a tvm.target.Target instance from a
- command line interface (CLI) string.
-
- Parameters
- ----------
- target : str
- compilation target as plain string,
- inline JSON or path to a JSON file
-
- additional_target_options: Optional[Dict[str, Dict[str,str]]]
- dictionary of additional target options to be
- combined with parsed targets
-
- Returns
- -------
- tvm.target.Target
- an instance of target device information
- extra_targets : list of dict
- This list preserves the order in which extra targets were
- provided via command line. Each Dict contains three keys:
- 'name', containing the name of the codegen; 'opts' containing
- a key-value for all options passed via CLI; 'raw',
- containing the plain string for this codegen
- """
- extra_targets = []
-
- if os.path.isfile(target):
- with open(target) as target_file:
- logger.debug("target input is a path: %s", target)
- target = "".join(target_file.readlines())
- elif is_inline_json(target):
- logger.debug("target input is inline JSON: %s", target)
- else:
- logger.debug("target input is plain text: %s", target)
- try:
- parsed_targets = parse_target(target)
- except ValueError as ex:
- raise TVMCException(f"Error parsing target string '{target}'.\nThe
error was: {ex}")
-
- validate_targets(parsed_targets, additional_target_options)
- tvm_targets = [
- _combine_target_options(t, additional_target_options)
- for t in parsed_targets
- if t["is_tvm_target"]
- ]
-
- # Validated target strings have 1 or 2 tvm targets, otherwise
- # `validate_targets` above will fail.
- if len(tvm_targets) == 1:
- target = _recombobulate_target(tvm_targets[0])
- target_host = None
- else:
- assert len(tvm_targets) == 2
- target = _recombobulate_target(tvm_targets[0])
- target_host = _recombobulate_target(tvm_targets[1])
-
- extra_targets = [t for t in parsed_targets if not t["is_tvm_target"]]
-
- return tvm.target.Target(target, host=target_host), extra_targets
-
-
-def tracker_host_port_from_cli(rpc_tracker_str):
- """Extract hostname and (optional) port from strings
- like "1.2.3.4:9090" or "4.3.2.1".
-
- Used as a helper function to cover --rpc-tracker
- command line argument, in different subcommands.
-
- Parameters
- ----------
- rpc_tracker_str : str
- hostname (or IP address) and port of the RPC tracker,
- in the format 'hostname[:port]'.
-
- Returns
- -------
- rpc_hostname : str or None
- hostname or IP address, extracted from input.
- rpc_port : int or None
- port number extracted from input (9090 default).
- """
-
- rpc_hostname = rpc_port = None
-
- if rpc_tracker_str:
- parsed_url = urlparse("//%s" % rpc_tracker_str)
- rpc_hostname = parsed_url.hostname
- rpc_port = parsed_url.port or 9090
- logger.info("RPC tracker hostname: %s", rpc_hostname)
- logger.info("RPC tracker port: %s", rpc_port)
-
- return rpc_hostname, rpc_port
-
-
-def parse_pass_list_str(input_string):
- """Parse an input string for existing passes
-
- Parameters
- ----------
- input_string: str
- Possibly comma-separated string with the names of passes
-
- Returns
- -------
- list: a list of existing passes.
- """
- _prefix = "relay._transform."
- pass_list = input_string.split(",")
- missing_list = [
- p.strip()
- for p in pass_list
- if len(p.strip()) > 0 and tvm.get_global_func(_prefix + p.strip(),
True) is None
- ]
- if len(missing_list) > 0:
- available_list = [
- n[len(_prefix) :] for n in registry.list_global_func_names() if
n.startswith(_prefix)
- ]
- raise argparse.ArgumentTypeError(
- "Following passes are not registered within tvm: {}. Available:
{}.".format(
- ", ".join(missing_list), ", ".join(sorted(available_list))
- )
- )
- return pass_list
-
-
-def parse_shape_string(inputs_string):
- """Parse an input shape dictionary string to a usable dictionary.
-
- Parameters
- ----------
- inputs_string: str
- A string of the form "input_name:[dim1,dim2,...,dimn]
input_name2:[dim1,dim2]" that
- indicates the desired shape for specific model inputs. Colons, forward
slashes and dots
- within input_names are supported. Spaces are supported inside of
dimension arrays.
-
- Returns
- -------
- shape_dict: dict
- A dictionary mapping input names to their shape for use in relay
frontend converters.
- """
-
- # Create a regex pattern that extracts each separate input mapping.
- # We want to be able to handle:
- # * Spaces inside arrays
- # * forward slashes inside names (but not at the beginning or end)
- # * colons inside names (but not at the beginning or end)
- # * dots inside names
- pattern = r"(?:\w+\/)?[:\w.]+\:\s*\[\-?\d+(?:\,\s*\-?\d+)*\]"
- input_mappings = re.findall(pattern, inputs_string)
- if not input_mappings:
- raise argparse.ArgumentTypeError(
- "--input-shapes argument must be of the form "
- '"input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]"'
- )
- shape_dict = {}
- for mapping in input_mappings:
- # Remove whitespace.
- mapping = mapping.replace(" ", "")
- # Split mapping into name and shape.
- name, shape_string = mapping.rsplit(":", 1)
- # Convert shape string into a list of integers or Anys if negative.
- shape = [int(x) if int(x) > 0 else relay.Any() for x in
shape_string.strip("][").split(",")]
- # Add parsed mapping to shape dictionary.
- shape_dict[name] = shape
-
- return shape_dict
-
-
-def get_pass_config_value(name, value, config_type):
- """Get a PassContext configuration value, based on its config data type.
-
- Parameters
- ----------
- name: str
- config identifier name.
- value: str
- value assigned to the config, provided via command line.
- config_type: str
- data type defined to the config, as string.
-
- Returns
- -------
- parsed_value: bool, int or str
- a representation of the input value, converted to the type
- specified by config_type.
- """
-
- if config_type == "IntImm":
- # "Bool" configurations in the PassContext are recognized as
- # IntImm, so deal with this case here
- mapping_values = {
- "false": False,
- "true": True,
- }
-
- if value.isdigit():
- parsed_value = int(value)
- else:
- # if not an int, accept only values on the mapping table, case
insensitive
- parsed_value = mapping_values.get(value.lower(), None)
-
- if parsed_value is None:
- raise TVMCException(f"Invalid value '{value}' for configuration
'{name}'. ")
-
- if config_type == "runtime.String":
- parsed_value = value
-
- return parsed_value
-
-
-def parse_configs(input_configs):
- """Parse configuration values set via command line.
-
- Parameters
- ----------
- input_configs: list of str
- list of configurations provided via command line.
-
- Returns
- -------
- pass_context_configs: dict
- a dict containing key-value configs to be used in the PassContext.
- """
- if not input_configs:
- return {}
-
- all_configs = tvm.ir.transform.PassContext.list_configs()
- supported_config_types = ("IntImm", "runtime.String")
- supported_configs = [
- name for name in all_configs.keys() if all_configs[name]["type"] in
supported_config_types
- ]
-
- pass_context_configs = {}
-
- for config in input_configs:
- if not config:
- raise TVMCException(
- f"Invalid format for configuration '{config}', use
<config>=<value>"
- )
-
- # Each config is expected to be provided as "name=value"
- try:
- name, value = config.split("=")
- name = name.strip()
- value = value.strip()
- except ValueError:
- raise TVMCException(
- f"Invalid format for configuration '{config}', use
<config>=<value>"
- )
-
- if name not in all_configs:
- raise TVMCException(
- f"Configuration '{name}' is not defined in TVM. "
- f"These are the existing configurations: {',
'.join(all_configs)}"
- )
-
- if name not in supported_configs:
- raise TVMCException(
- f"Configuration '{name}' uses a data type not supported by
TVMC. "
- f"The following configurations are supported: {',
'.join(supported_configs)}"
- )
-
- parsed_value = get_pass_config_value(name, value,
all_configs[name]["type"])
- pass_context_configs[name] = parsed_value
-
- return pass_context_configs
-
-
-def get_project_options(project_info):
- """Get all project options as returned by Project API 'server_info_query'
- and return them in a dict indexed by the API method they belong to.
-
-
- Parameters
- ----------
- project_info: dict of list
- a dict of lists as returned by Project API 'server_info_query' among
- which there is a list called 'project_options' containing all the
- project options available for a given project/platform.
-
- Returns
- -------
- options_by_method: dict of list
- a dict indexed by the API method names (e.g. "generate_project",
- "build", "flash", or "open_transport") of lists containing all the
- options (plus associated metadata and formatted help text) that belong
- to a method.
-
- The metadata associated to the options include the field 'choices' and
- 'required' which are convenient for parsers.
-
- The formatted help text field 'help_text' is a string that contains the
- name of the option, the choices for the option, and the option's
default
- value.
- """
- options = project_info["project_options"]
-
- options_by_method = defaultdict(list)
- for opt in options:
- # Get list of methods associated with an option based on the
- # existance of a 'required' or 'optional' lists. API specification
- # guarantees at least one of these lists will exist. If a list does
- # not exist it's returned as None by the API.
- metadata = ["required", "optional"]
- om = [(opt[md], bool(md == "required")) for md in metadata if opt[md]]
- for methods, is_opt_required in om:
- for method in methods:
- name = opt["name"]
-
- # Only for boolean options set 'choices' accordingly to the
- # option type. API returns 'choices' associated to them
- # as None but 'choices' can be deduced from 'type' in this
case.
- if opt["type"] == "bool":
- opt["choices"] = ["true", "false"]
-
- if opt["choices"]:
- choices = "{" + ", ".join(opt["choices"]) + "}"
- else:
- choices = opt["name"].upper()
- option_choices_text = f"{name}={choices}"
-
- help_text = opt["help"][0].lower() + opt["help"][1:]
-
- if opt["default"]:
- default_text = f"Defaults to '{opt['default']}'."
- else:
- default_text = None
-
- formatted_help_text = format_option(
- option_choices_text, help_text, default_text,
is_opt_required
- )
-
- option = {
- "name": opt["name"],
- "choices": opt["choices"],
- "help_text": formatted_help_text,
- "required": is_opt_required,
- }
- options_by_method[method].append(option)
-
- return options_by_method
-
-
-def get_options(options):
- """Get option and option value from the list options returned by the
parser.
-
- Parameters
- ----------
- options: list of str
- list of strings of the form "option=value" as returned by the parser.
-
- Returns
- -------
- opts: dict
- dict indexed by option names and associated values.
- """
-
- opts = {}
- for option in options:
- try:
- k, v = option.split("=")
- opts[k] = v
- except ValueError:
- raise TVMCException(f"Invalid option format: {option}. Please use
OPTION=VALUE.")
-
- return opts
-
-
-def check_options(options, valid_options):
- """Check if an option (required or optional) is valid. i.e. in the list of
valid options.
-
- Parameters
- ----------
- options: dict
- dict indexed by option name of options and options values to be
checked.
-
- valid_options: list of dict
- list of all valid options and choices for a platform.
-
- Returns
- -------
- None. Raise TVMCException if check fails, i.e. if an option is not in the
list of valid options.
-
- """
- required_options = [opt["name"] for opt in valid_options if
opt["required"]]
- for required_option in required_options:
- if required_option not in options:
- raise TVMCException(
- f"Option '{required_option}' is required but was not
specified. Use --list-options "
- "to see all required options."
- )
-
- remaining_options = set(options) - set(required_options)
- optional_options = [opt["name"] for opt in valid_options if not
opt["required"]]
- for option in remaining_options:
- if option not in optional_options:
- raise TVMCException(
- f"Option '{option}' is invalid. Use --list-options to see all
available options."
- )
-
-
-def check_options_choices(options, valid_options):
- """Check if an option value is among the option's choices, when choices
exist.
-
- Parameters
- ----------
- options: dict
- dict indexed by option name of options and options values to be
checked.
-
- valid_options: list of dict
- list of all valid options and choices for a platform.
-
- Returns
- -------
- None. Raise TVMCException if check fails, i.e. if an option value is not
valid.
-
- """
- # Dict of all valid options and associated valid choices.
- # Options with no choices are excluded from the dict.
- valid_options_choices = {
- opt["name"]: opt["choices"] for opt in valid_options if opt["choices"]
is not None
- }
-
- for option in options:
- if option in valid_options_choices:
- if options[option] not in valid_options_choices[option]:
- raise TVMCException(
- f"Choice '{options[option]}' for option '{option}' is
invalid. "
- "Use --list-options to see all available choices for that
option."
- )
-
-
-def get_and_check_options(passed_options, valid_options):
- """Get options and check if they are valid. If choices exist for them,
check values against it.
-
- Parameters
- ----------
- passed_options: list of str
- list of strings in the "key=value" form as captured by argparse.
-
- valid_option: list
- list with all options available for a given API method / project as
returned by
- get_project_options().
-
- Returns
- -------
- opts: dict
- dict indexed by option names and associated values.
-
- Or None if passed_options is None.
-
- """
-
- if passed_options is None:
- # No options to check
- return None
-
- # From a list of k=v strings, make a dict options[k]=v
- opts = get_options(passed_options)
- # Check if passed options are valid
- check_options(opts, valid_options)
- # Check (when a list of choices exists) if the passed values are valid
- check_options_choices(opts, valid_options)
-
- return opts
-
-
-def get_project_dir(project_dir: Union[pathlib.Path, str]) -> str:
- """Get project directory path"""
- if not os.path.isabs(project_dir):
- return os.path.abspath(project_dir)
- return project_dir
diff --git a/python/tvm/driver/tvmc/compiler.py
b/python/tvm/driver/tvmc/compiler.py
index dbf7e46..d260c98 100644
--- a/python/tvm/driver/tvmc/compiler.py
+++ b/python/tvm/driver/tvmc/compiler.py
@@ -29,11 +29,14 @@ from tvm.driver.tvmc.registry import
generate_registry_args, reconstruct_registr
from tvm.target import Target
from tvm.relay.backend import Executor, Runtime
-from . import common, composite_target, frontends
+from . import composite_target, frontends
from .model import TVMCModel, TVMCPackage
from .main import register_parser
-from .target import generate_target_args, reconstruct_target_args
-
+from .target import target_from_cli, generate_target_args,
reconstruct_target_args
+from .pass_config import parse_configs
+from .pass_list import parse_pass_list_str
+from .transform import convert_graph_layout
+from .shape_parser import parse_shape_string
# pylint: disable=invalid-name
logger = logging.getLogger("TVMC")
@@ -124,13 +127,13 @@ def add_compile_parser(subparsers, _):
"--input-shapes",
help="specify non-generic shapes for model to run, format is "
'"input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]".',
- type=common.parse_shape_string,
+ type=parse_shape_string,
default=None,
)
parser.add_argument(
"--disabled-pass",
help="disable specific passes, comma-separated list of pass names.",
- type=common.parse_pass_list_str,
+ type=parse_pass_list_str,
default="",
)
@@ -249,12 +252,12 @@ def compile_model(
"""
mod, params = tvmc_model.mod, tvmc_model.params
- config = common.parse_configs(pass_context_configs)
+ config = parse_configs(pass_context_configs)
if desired_layout:
- mod = common.convert_graph_layout(mod, desired_layout)
+ mod = convert_graph_layout(mod, desired_layout)
- tvm_target, extra_targets = common.target_from_cli(target,
additional_target_options)
+ tvm_target, extra_targets = target_from_cli(target,
additional_target_options)
tvm_target, target_host = Target.check_and_update_host_consist(tvm_target,
target_host)
for codegen_from_cli in extra_targets:
diff --git a/python/tvm/driver/tvmc/composite_target.py
b/python/tvm/driver/tvmc/composite_target.py
index 848af1e..f347158 100644
--- a/python/tvm/driver/tvmc/composite_target.py
+++ b/python/tvm/driver/tvmc/composite_target.py
@@ -31,7 +31,7 @@ from tvm.relay.op.contrib.bnns import partition_for_bnns
from tvm.relay.op.contrib.vitis_ai import partition_for_vitis_ai
-from .common import TVMCException
+from tvm.driver.tvmc import TVMCException
# pylint: disable=invalid-name
diff --git a/python/tvm/driver/tvmc/frontends.py
b/python/tvm/driver/tvmc/frontends.py
index b6773dc..a322278 100644
--- a/python/tvm/driver/tvmc/frontends.py
+++ b/python/tvm/driver/tvmc/frontends.py
@@ -31,8 +31,7 @@ from pathlib import Path
import numpy as np
from tvm import relay
-from tvm.driver.tvmc.common import TVMCException
-from tvm.driver.tvmc.common import TVMCImportError
+from tvm.driver.tvmc import TVMCException, TVMCImportError
from tvm.driver.tvmc.model import TVMCModel
diff --git a/python/tvm/driver/tvmc/main.py b/python/tvm/driver/tvmc/main.py
index 3fb8cd7..b74cc7d 100644
--- a/python/tvm/driver/tvmc/main.py
+++ b/python/tvm/driver/tvmc/main.py
@@ -25,8 +25,7 @@ import sys
import tvm
-from tvm.driver.tvmc.common import TVMCException
-from tvm.driver.tvmc.common import TVMCImportError
+from tvm.driver.tvmc import TVMCException, TVMCImportError
REGISTERED_PARSER = []
diff --git a/python/tvm/driver/tvmc/micro.py b/python/tvm/driver/tvmc/micro.py
index a9c17b8..4f478c7 100644
--- a/python/tvm/driver/tvmc/micro.py
+++ b/python/tvm/driver/tvmc/micro.py
@@ -23,10 +23,10 @@ from pathlib import Path
import shutil
import sys
+from . import TVMCException
from .main import register_parser
-from .common import (
- TVMCException,
- TVMCSuppressedArgumentParser,
+from .arguments import TVMCSuppressedArgumentParser
+from .project import (
get_project_options,
get_and_check_options,
get_project_dir,
diff --git a/python/tvm/driver/tvmc/model.py b/python/tvm/driver/tvmc/model.py
index 5110aed..9a2617f 100644
--- a/python/tvm/driver/tvmc/model.py
+++ b/python/tvm/driver/tvmc/model.py
@@ -54,6 +54,7 @@ import tvm
import tvm.contrib.cc
from tvm import relay
from tvm.contrib import utils
+from tvm.driver.tvmc import TVMCException
from tvm.relay.backend.executor_factory import GraphExecutorFactoryModule
from tvm.runtime.module import BenchmarkResult
@@ -62,8 +63,6 @@ try:
except ImportError:
export_model_library_format = None
-from .common import TVMCException
-
class TVMCModel(object):
"""Initialize a TVMC model from a relay model definition or a saved file.
diff --git a/python/tvm/driver/tvmc/pass_config.py
b/python/tvm/driver/tvmc/pass_config.py
new file mode 100644
index 0000000..7cf0f01
--- /dev/null
+++ b/python/tvm/driver/tvmc/pass_config.py
@@ -0,0 +1,122 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+TVMC PassContext Interface
+"""
+
+import tvm
+from tvm.driver.tvmc import TVMCException
+
+
+def get_pass_config_value(name, value, config_type):
+ """Get a PassContext configuration value, based on its config data type.
+
+ Parameters
+ ----------
+ name: str
+ config identifier name.
+ value: str
+ value assigned to the config, provided via command line.
+ config_type: str
+ data type defined to the config, as string.
+
+ Returns
+ -------
+ parsed_value: bool, int or str
+ a representation of the input value, converted to the type
+ specified by config_type.
+ """
+
+ if config_type == "IntImm":
+ # "Bool" configurations in the PassContext are recognized as
+ # IntImm, so deal with this case here
+ mapping_values = {
+ "false": False,
+ "true": True,
+ }
+
+ if value.isdigit():
+ parsed_value = int(value)
+ else:
+ # if not an int, accept only values on the mapping table, case
insensitive
+ parsed_value = mapping_values.get(value.lower(), None)
+
+ if parsed_value is None:
+ raise TVMCException(f"Invalid value '{value}' for configuration
'{name}'. ")
+
+ if config_type == "runtime.String":
+ parsed_value = value
+
+ return parsed_value
+
+
+def parse_configs(input_configs):
+ """Parse configuration values set via command line.
+
+ Parameters
+ ----------
+ input_configs: list of str
+ list of configurations provided via command line.
+
+ Returns
+ -------
+ pass_context_configs: dict
+ a dict containing key-value configs to be used in the PassContext.
+ """
+ if not input_configs:
+ return {}
+
+ all_configs = tvm.ir.transform.PassContext.list_configs()
+ supported_config_types = ("IntImm", "runtime.String")
+ supported_configs = [
+ name for name in all_configs.keys() if all_configs[name]["type"] in
supported_config_types
+ ]
+
+ pass_context_configs = {}
+
+ for config in input_configs:
+ if not config:
+ raise TVMCException(
+ f"Invalid format for configuration '{config}', use
<config>=<value>"
+ )
+
+ # Each config is expected to be provided as "name=value"
+ try:
+ name, value = config.split("=")
+ name = name.strip()
+ value = value.strip()
+ except ValueError:
+ raise TVMCException(
+ f"Invalid format for configuration '{config}', use
<config>=<value>"
+ )
+
+ if name not in all_configs:
+ raise TVMCException(
+ f"Configuration '{name}' is not defined in TVM. "
+ f"These are the existing configurations: {',
'.join(all_configs)}"
+ )
+
+ if name not in supported_configs:
+ raise TVMCException(
+ f"Configuration '{name}' uses a data type not supported by
TVMC. "
+ f"The following configurations are supported: {',
'.join(supported_configs)}"
+ )
+
+ parsed_value = get_pass_config_value(name, value,
all_configs[name]["type"])
+ pass_context_configs[name] = parsed_value
+
+ return pass_context_configs
diff --git a/python/tvm/driver/tvmc/pass_list.py
b/python/tvm/driver/tvmc/pass_list.py
new file mode 100644
index 0000000..09ec6aa
--- /dev/null
+++ b/python/tvm/driver/tvmc/pass_list.py
@@ -0,0 +1,54 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language
+"""
+TVMC Pass List Management
+"""
+
+import argparse
+
+import tvm
+from tvm._ffi import registry
+
+
+def parse_pass_list_str(input_string):
+ """Parse an input string for existing passes
+
+ Parameters
+ ----------
+ input_string: str
+ Possibly comma-separated string with the names of passes
+
+ Returns
+ -------
+ list: a list of existing passes.
+ """
+ _prefix = "relay._transform."
+ pass_list = input_string.split(",")
+ missing_list = [
+ p.strip()
+ for p in pass_list
+ if len(p.strip()) > 0 and tvm.get_global_func(_prefix + p.strip(),
True) is None
+ ]
+ if len(missing_list) > 0:
+ available_list = [
+ n[len(_prefix) :] for n in registry.list_global_func_names() if
n.startswith(_prefix)
+ ]
+ raise argparse.ArgumentTypeError(
+ "Following passes are not registered within tvm: {}. Available:
{}.".format(
+ ", ".join(missing_list), ", ".join(sorted(available_list))
+ )
+ )
+ return pass_list
diff --git a/python/tvm/driver/tvmc/project.py
b/python/tvm/driver/tvmc/project.py
new file mode 100644
index 0000000..d9b22a2
--- /dev/null
+++ b/python/tvm/driver/tvmc/project.py
@@ -0,0 +1,233 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+TVMC Project Generation Functions
+"""
+
+import os
+import pathlib
+from collections import defaultdict
+from typing import Union
+
+from . import TVMCException
+from .fmtopt import format_option
+
+
+def get_project_options(project_info):
+ """Get all project options as returned by Project API 'server_info_query'
+ and return them in a dict indexed by the API method they belong to.
+
+
+ Parameters
+ ----------
+ project_info: dict of list
+ a dict of lists as returned by Project API 'server_info_query' among
+ which there is a list called 'project_options' containing all the
+ project options available for a given project/platform.
+
+ Returns
+ -------
+ options_by_method: dict of list
+ a dict indexed by the API method names (e.g. "generate_project",
+ "build", "flash", or "open_transport") of lists containing all the
+ options (plus associated metadata and formatted help text) that belong
+ to a method.
+
+ The metadata associated to the options include the field 'choices' and
+ 'required' which are convenient for parsers.
+
+ The formatted help text field 'help_text' is a string that contains the
+ name of the option, the choices for the option, and the option's
default
+ value.
+ """
+ options = project_info["project_options"]
+
+ options_by_method = defaultdict(list)
+ for opt in options:
+ # Get list of methods associated with an option based on the
+ # existance of a 'required' or 'optional' lists. API specification
+ # guarantees at least one of these lists will exist. If a list does
+ # not exist it's returned as None by the API.
+ metadata = ["required", "optional"]
+ option_methods = [(opt[md], bool(md == "required")) for md in metadata
if opt[md]]
+ for methods, is_opt_required in option_methods:
+ for method in methods:
+ name = opt["name"]
+
+ # Only for boolean options set 'choices' accordingly to the
+ # option type. API returns 'choices' associated to them
+ # as None but 'choices' can be deduced from 'type' in this
case.
+ if opt["type"] == "bool":
+ opt["choices"] = ["true", "false"]
+
+ if opt["choices"]:
+ choices = "{" + ", ".join(opt["choices"]) + "}"
+ else:
+ choices = opt["name"].upper()
+ option_choices_text = f"{name}={choices}"
+
+ help_text = opt["help"][0].lower() + opt["help"][1:]
+
+ if opt["default"]:
+ default_text = f"Defaults to '{opt['default']}'."
+ else:
+ default_text = None
+
+ formatted_help_text = format_option(
+ option_choices_text, help_text, default_text,
is_opt_required
+ )
+
+ option = {
+ "name": opt["name"],
+ "choices": opt["choices"],
+ "help_text": formatted_help_text,
+ "required": is_opt_required,
+ }
+ options_by_method[method].append(option)
+
+ return options_by_method
+
+
+def get_options(options):
+ """Get option and option value from the list options returned by the
parser.
+
+ Parameters
+ ----------
+ options: list of str
+ list of strings of the form "option=value" as returned by the parser.
+
+ Returns
+ -------
+ opts: dict
+ dict indexed by option names and associated values.
+ """
+
+ opts = {}
+ for option in options:
+ try:
+ k, v = option.split("=")
+ opts[k] = v
+ except ValueError:
+ raise TVMCException(f"Invalid option format: {option}. Please use
OPTION=VALUE.")
+
+ return opts
+
+
+def check_options(options, valid_options):
+ """Check if an option (required or optional) is valid. i.e. in the list of
valid options.
+
+ Parameters
+ ----------
+ options: dict
+ dict indexed by option name of options and options values to be
checked.
+
+ valid_options: list of dict
+ list of all valid options and choices for a platform.
+
+ Returns
+ -------
+ None. Raise TVMCException if check fails, i.e. if an option is not in the
list of valid options.
+
+ """
+ required_options = [opt["name"] for opt in valid_options if
opt["required"]]
+ for required_option in required_options:
+ if required_option not in options:
+ raise TVMCException(
+ f"Option '{required_option}' is required but was not
specified. Use --list-options "
+ "to see all required options."
+ )
+
+ remaining_options = set(options) - set(required_options)
+ optional_options = [opt["name"] for opt in valid_options if not
opt["required"]]
+ for option in remaining_options:
+ if option not in optional_options:
+ raise TVMCException(
+ f"Option '{option}' is invalid. Use --list-options to see all
available options."
+ )
+
+
+def check_options_choices(options, valid_options):
+ """Check if an option value is among the option's choices, when choices
exist.
+
+ Parameters
+ ----------
+ options: dict
+ dict indexed by option name of options and options values to be
checked.
+
+ valid_options: list of dict
+ list of all valid options and choices for a platform.
+
+ Returns
+ -------
+ None. Raise TVMCException if check fails, i.e. if an option value is not
valid.
+
+ """
+ # Dict of all valid options and associated valid choices.
+ # Options with no choices are excluded from the dict.
+ valid_options_choices = {
+ opt["name"]: opt["choices"] for opt in valid_options if opt["choices"]
is not None
+ }
+
+ for option in options:
+ if option in valid_options_choices:
+ if options[option] not in valid_options_choices[option]:
+ raise TVMCException(
+ f"Choice '{options[option]}' for option '{option}' is
invalid. "
+ "Use --list-options to see all available choices for that
option."
+ )
+
+
+def get_and_check_options(passed_options, valid_options):
+ """Get options and check if they are valid. If choices exist for them,
check values against it.
+
+ Parameters
+ ----------
+ passed_options: list of str
+ list of strings in the "key=value" form as captured by argparse.
+
+ valid_option: list
+ list with all options available for a given API method / project as
returned by
+ get_project_options().
+
+ Returns
+ -------
+ opts: dict
+ dict indexed by option names and associated values.
+
+ Or None if passed_options is None.
+
+ """
+
+ if passed_options is None:
+ # No options to check
+ return None
+
+ # From a list of k=v strings, make a dict options[k]=v
+ opts = get_options(passed_options)
+ # Check if passed options are valid
+ check_options(opts, valid_options)
+ # Check (when a list of choices exists) if the passed values are valid
+ check_options_choices(opts, valid_options)
+
+ return opts
+
+
+def get_project_dir(project_dir: Union[pathlib.Path, str]) -> str:
+ """Get project directory path"""
+ if not os.path.isabs(project_dir):
+ return os.path.abspath(project_dir)
+ return project_dir
diff --git a/python/tvm/driver/tvmc/registry.py
b/python/tvm/driver/tvmc/registry.py
index 384a3bd..334aa1b 100644
--- a/python/tvm/driver/tvmc/registry.py
+++ b/python/tvm/driver/tvmc/registry.py
@@ -18,7 +18,7 @@
This file contains functions for processing registry based inputs for the TVMC
CLI
"""
-from tvm.driver.tvmc.common import TVMCException
+from tvm.driver.tvmc import TVMCException
# We can't tell the type inside an Array but all current options are strings so
# it can default to that. Bool is used alongside Integer but aren't
distinguished
diff --git a/python/tvm/driver/tvmc/runner.py b/python/tvm/driver/tvmc/runner.py
index fd342a5..a234396 100644
--- a/python/tvm/driver/tvmc/runner.py
+++ b/python/tvm/driver/tvmc/runner.py
@@ -33,17 +33,18 @@ from tvm.autotvm.measure import request_remote
from tvm.contrib import graph_executor as runtime
from tvm.contrib.debugger import debug_executor
from tvm.relay.param_dict import load_param_dict
-from . import common
-from .common import (
- TVMCException,
- TVMCSuppressedArgumentParser,
+from . import TVMCException
+from .arguments import TVMCSuppressedArgumentParser
+from .project import (
get_project_options,
get_and_check_options,
get_project_dir,
)
+
from .main import register_parser
from .model import TVMCPackage, TVMCResult
from .result_utils import get_top_results
+from .tracker import tracker_host_port_from_cli
try:
import tvm.micro.project as project
@@ -245,7 +246,7 @@ def drive_run(args):
except ReadError:
raise TVMCException(f"Could not read model from archive {path}!")
- rpc_hostname, rpc_port =
common.tracker_host_port_from_cli(args.rpc_tracker)
+ rpc_hostname, rpc_port = tracker_host_port_from_cli(args.rpc_tracker)
try:
inputs = np.load(args.inputs) if args.inputs else {}
diff --git a/python/tvm/driver/tvmc/shape_parser.py
b/python/tvm/driver/tvmc/shape_parser.py
new file mode 100644
index 0000000..24b7727
--- /dev/null
+++ b/python/tvm/driver/tvmc/shape_parser.py
@@ -0,0 +1,67 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+TVMC Shape Parsing
+"""
+
+import argparse
+import re
+
+from tvm import relay
+
+
+def parse_shape_string(inputs_string):
+ """Parse an input shape dictionary string to a usable dictionary.
+
+ Parameters
+ ----------
+ inputs_string: str
+ A string of the form "input_name:[dim1,dim2,...,dimn]
input_name2:[dim1,dim2]" that
+ indicates the desired shape for specific model inputs. Colons, forward
slashes and dots
+ within input_names are supported. Spaces are supported inside of
dimension arrays.
+
+ Returns
+ -------
+ shape_dict: dict
+ A dictionary mapping input names to their shape for use in relay
frontend converters.
+ """
+
+ # Create a regex pattern that extracts each separate input mapping.
+ # We want to be able to handle:
+ # * Spaces inside arrays
+ # * forward slashes inside names (but not at the beginning or end)
+ # * colons inside names (but not at the beginning or end)
+ # * dots inside names
+ pattern = r"(?:\w+\/)?[:\w.]+\:\s*\[\-?\d+(?:\,\s*\-?\d+)*\]"
+ input_mappings = re.findall(pattern, inputs_string)
+ if not input_mappings:
+ raise argparse.ArgumentTypeError(
+ "--input-shapes argument must be of the form "
+ '"input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]"'
+ )
+ shape_dict = {}
+ for mapping in input_mappings:
+ # Remove whitespace.
+ mapping = mapping.replace(" ", "")
+ # Split mapping into name and shape.
+ name, shape_string = mapping.rsplit(":", 1)
+ # Convert shape string into a list of integers or Anys if negative.
+ shape = [int(x) if int(x) > 0 else relay.Any() for x in
shape_string.strip("][").split(",")]
+ # Add parsed mapping to shape dictionary.
+ shape_dict[name] = shape
+
+ return shape_dict
diff --git a/python/tvm/driver/tvmc/target.py b/python/tvm/driver/tvmc/target.py
index 067a361..15ed19d 100644
--- a/python/tvm/driver/tvmc/target.py
+++ b/python/tvm/driver/tvmc/target.py
@@ -18,9 +18,19 @@
This file contains functions for processing target inputs for the TVMC CLI
"""
+import os
+import logging
+import json
+import re
+
+import tvm
from tvm.driver import tvmc
+from tvm.driver.tvmc import TVMCException
from tvm.target import Target, TargetKind
+# pylint: disable=invalid-name
+logger = logging.getLogger("TVMC")
+
# We can't tell the type inside an Array but all current options are strings so
# it can default to that. Bool is used alongside Integer but aren't
distinguished
# between as both are represented by IntImm
@@ -74,3 +84,271 @@ def reconstruct_target_args(args):
if kind_options:
reconstructed[target_kind] = kind_options
return reconstructed
+
+
+def validate_targets(parse_targets, additional_target_options=None):
+ """
+ Apply a series of validations in the targets provided via CLI.
+ """
+ tvm_target_kinds = tvm.target.Target.list_kinds()
+ targets = [t["name"] for t in parse_targets]
+
+ if len(targets) > len(set(targets)):
+ raise TVMCException("Duplicate target definitions are not allowed")
+
+ if targets[-1] not in tvm_target_kinds:
+ tvm_target_names = ", ".join(tvm_target_kinds)
+ raise TVMCException(
+ f"The last target needs to be a TVM target. Choices:
{tvm_target_names}"
+ )
+
+ tvm_targets = [t for t in targets if t in tvm_target_kinds]
+ if len(tvm_targets) > 2:
+ verbose_tvm_targets = ", ".join(tvm_targets)
+ raise TVMCException(
+ "Only two of the following targets can be used at a time. "
+ f"Found: {verbose_tvm_targets}."
+ )
+
+ if additional_target_options is not None:
+ for target_name in additional_target_options:
+ if not any([target for target in parse_targets if target["name"]
== target_name]):
+ first_option =
list(additional_target_options[target_name].keys())[0]
+ raise TVMCException(
+ f"Passed --target-{target_name}-{first_option}"
+ f" but did not specify {target_name} target"
+ )
+
+
+def tokenize_target(target):
+ """
+ Extract a list of tokens from a target specification text.
+
+ It covers some corner-cases that are not covered by the built-in
+ module 'shlex', such as the use of "+" as a punctuation character.
+
+
+ Example
+ -------
+
+ For the input `foo -op1=v1 -op2="v ,2", bar -op3=v-4` we
+ should obtain:
+
+ ["foo", "-op1=v1", "-op2="v ,2"", ",", "bar", "-op3=v-4"]
+
+ Parameters
+ ----------
+ target : str
+ Target options sent via CLI arguments
+
+ Returns
+ -------
+ list of str
+ a list of parsed tokens extracted from the target string
+ """
+
+ # Regex to tokenize the "--target" value. It is split into five parts
+ # to match with:
+ # 1. target and option names e.g. llvm, -mattr=, -mcpu=
+ # 2. option values, all together, without quotes e.g. -mattr=+foo,+opt
+ # 3. option values, when single quotes are used e.g. -mattr='+foo, +opt'
+ # 4. option values, when double quotes are used e.g. -mattr="+foo ,+opt"
+ # 5. commas that separate different targets e.g. "my-target, llvm"
+ target_pattern = (
+ r"(\-{0,2}[\w\-]+\=?"
+ r"(?:[\w\+\-\.]+(?:,[\w\+\-\.])*"
+ r"|[\'][\w\+\-,\s\.]+[\']"
+ r"|[\"][\w\+\-,\s\.]+[\"])*"
+ r"|,)"
+ )
+
+ return re.findall(target_pattern, target)
+
+
+def parse_target(target):
+ """
+ Parse a plain string of targets provided via a command-line
+ argument.
+
+ To send more than one codegen, a comma-separated list
+ is expected. Options start with -<option_name>=<value>.
+
+ We use python standard library 'shlex' to parse the argument in
+ a POSIX compatible way, so that if options are defined as
+ strings with spaces or commas, for example, this is considered
+ and parsed accordingly.
+
+
+ Example
+ -------
+
+ For the input `--target="foo -op1=v1 -op2="v ,2", bar -op3=v-4"` we
+ should obtain:
+
+ [
+ {
+ name: "foo",
+ opts: {"op1":"v1", "op2":"v ,2"},
+ raw: 'foo -op1=v1 -op2="v ,2"'
+ },
+ {
+ name: "bar",
+ opts: {"op3":"v-4"},
+ raw: 'bar -op3=v-4'
+ }
+ ]
+
+ Parameters
+ ----------
+ target : str
+ Target options sent via CLI arguments
+
+ Returns
+ -------
+ codegens : list of dict
+ This list preserves the order in which codegens were
+ provided via command line. Each Dict contains three keys:
+ 'name', containing the name of the codegen; 'opts' containing
+ a key-value for all options passed via CLI; 'raw',
+ containing the plain string for this codegen
+ """
+ codegen_names = tvmc.composite_target.get_codegen_names()
+ codegens = []
+
+ tvm_target_kinds = tvm.target.Target.list_kinds()
+ parsed_tokens = tokenize_target(target)
+
+ split_codegens = []
+ current_codegen = []
+ split_codegens.append(current_codegen)
+ for token in parsed_tokens:
+ # every time there is a comma separating
+ # two codegen definitions, prepare for
+ # a new codegen
+ if token == ",":
+ current_codegen = []
+ split_codegens.append(current_codegen)
+ else:
+ # collect a new token for the current
+ # codegen being parsed
+ current_codegen.append(token)
+
+ # at this point we have a list of lists,
+ # each item on the first list is a codegen definition
+ # in the comma-separated values
+ for codegen_def in split_codegens:
+ # the first is expected to be the name
+ name = codegen_def[0]
+ is_tvm_target = name in tvm_target_kinds and name not in codegen_names
+ raw_target = " ".join(codegen_def)
+ all_opts = codegen_def[1:] if len(codegen_def) > 1 else []
+ opts = {}
+ for opt in all_opts:
+ try:
+ # deal with -- prefixed flags
+ if opt.startswith("--"):
+ opt_name = opt[2:]
+ opt_value = True
+ else:
+ opt = opt[1:] if opt.startswith("-") else opt
+ opt_name, opt_value = opt.split("=", maxsplit=1)
+
+ # remove quotes from the value: quotes are only parsed if
they match,
+ # so it is safe to assume that if the string starts with
quote, it ends
+ # with quote.
+ opt_value = opt_value[1:-1] if opt_value[0] in ('"', "'")
else opt_value
+ except ValueError:
+ raise ValueError(f"Error when parsing '{opt}'")
+
+ opts[opt_name] = opt_value
+
+ codegens.append(
+ {"name": name, "opts": opts, "raw": raw_target, "is_tvm_target":
is_tvm_target}
+ )
+
+ return codegens
+
+
+def is_inline_json(target):
+ try:
+ json.loads(target)
+ return True
+ except json.decoder.JSONDecodeError:
+ return False
+
+
+def _combine_target_options(target, additional_target_options=None):
+ if additional_target_options is None:
+ return target
+ if target["name"] in additional_target_options:
+ target["opts"].update(additional_target_options[target["name"]])
+ return target
+
+
+def _recombobulate_target(target):
+ name = target["name"]
+ opts = " ".join([f"-{key}={value}" for key, value in
target["opts"].items()])
+ return f"{name} {opts}"
+
+
+def target_from_cli(target, additional_target_options=None):
+ """
+ Create a tvm.target.Target instance from a
+ command line interface (CLI) string.
+
+ Parameters
+ ----------
+ target : str
+ compilation target as plain string,
+ inline JSON or path to a JSON file
+
+ additional_target_options: Optional[Dict[str, Dict[str,str]]]
+ dictionary of additional target options to be
+ combined with parsed targets
+
+ Returns
+ -------
+ tvm.target.Target
+ an instance of target device information
+ extra_targets : list of dict
+ This list preserves the order in which extra targets were
+ provided via command line. Each Dict contains three keys:
+ 'name', containing the name of the codegen; 'opts' containing
+ a key-value for all options passed via CLI; 'raw',
+ containing the plain string for this codegen
+ """
+ extra_targets = []
+
+ if os.path.isfile(target):
+ with open(target) as target_file:
+ logger.debug("target input is a path: %s", target)
+ target = "".join(target_file.readlines())
+ elif is_inline_json(target):
+ logger.debug("target input is inline JSON: %s", target)
+ else:
+ logger.debug("target input is plain text: %s", target)
+ try:
+ parsed_targets = parse_target(target)
+ except ValueError as error:
+ raise TVMCException(f"Error parsing target string '{target}'.\nThe
error was: {error}")
+
+ validate_targets(parsed_targets, additional_target_options)
+ tvm_targets = [
+ _combine_target_options(t, additional_target_options)
+ for t in parsed_targets
+ if t["is_tvm_target"]
+ ]
+
+ # Validated target strings have 1 or 2 tvm targets, otherwise
+ # `validate_targets` above will fail.
+ if len(tvm_targets) == 1:
+ target = _recombobulate_target(tvm_targets[0])
+ target_host = None
+ else:
+ assert len(tvm_targets) == 2
+ target = _recombobulate_target(tvm_targets[0])
+ target_host = _recombobulate_target(tvm_targets[1])
+
+ extra_targets = [t for t in parsed_targets if not t["is_tvm_target"]]
+
+ return tvm.target.Target(target, host=target_host), extra_targets
diff --git a/python/tvm/driver/tvmc/tracker.py
b/python/tvm/driver/tvmc/tracker.py
new file mode 100644
index 0000000..65fda42
--- /dev/null
+++ b/python/tvm/driver/tvmc/tracker.py
@@ -0,0 +1,57 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language
+"""
+TVMC Remote Tracker
+"""
+
+import logging
+from urllib.parse import urlparse
+
+# pylint: disable=invalid-name
+logger = logging.getLogger("TVMC")
+
+
+def tracker_host_port_from_cli(rpc_tracker_str):
+ """Extract hostname and (optional) port from strings
+ like "1.2.3.4:9090" or "4.3.2.1".
+
+ Used as a helper function to cover --rpc-tracker
+ command line argument, in different subcommands.
+
+ Parameters
+ ----------
+ rpc_tracker_str : str
+ hostname (or IP address) and port of the RPC tracker,
+ in the format 'hostname[:port]'.
+
+ Returns
+ -------
+ rpc_hostname : str or None
+ hostname or IP address, extracted from input.
+ rpc_port : int or None
+ port number extracted from input (9090 default).
+ """
+
+ rpc_hostname = rpc_port = None
+
+ if rpc_tracker_str:
+ parsed_url = urlparse("//%s" % rpc_tracker_str)
+ rpc_hostname = parsed_url.hostname
+ rpc_port = parsed_url.port or 9090
+ logger.info("RPC tracker hostname: %s", rpc_hostname)
+ logger.info("RPC tracker port: %s", rpc_port)
+
+ return rpc_hostname, rpc_port
diff --git a/python/tvm/driver/tvmc/transform.py
b/python/tvm/driver/tvmc/transform.py
new file mode 100644
index 0000000..3f77765
--- /dev/null
+++ b/python/tvm/driver/tvmc/transform.py
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+TVMC Graph Transforms
+"""
+
+from tvm import relay, transform
+from tvm.driver.tvmc import TVMCException
+
+
def convert_graph_layout(mod, desired_layout):
    """Alter the layout of the input graph.

    Runs ``RemoveUnusedFunctions`` followed by ``ConvertLayout`` on *mod*.
    The conversion requests *desired_layout* as the data layout of
    ``nn.conv2d``, ``nn.conv2d_transpose`` and ``qnn.conv2d``.

    Parameters
    ----------
    mod : tvm.IRModule
        The relay module to convert.
    desired_layout : str
        The layout to convert to (e.g. "NCHW" or "NHWC").

    Returns
    -------
    mod : tvm.IRModule
        The converted module.

    Raises
    ------
    TVMCException
        If running the conversion passes raises any exception. The original
        error is attached as ``__cause__``.
    """

    # Assume for the time being that convolutions (float and quantized,
    # including transposed) are the only heavily layout-sensitive operators.
    # The second entry, "default", leaves the choice of kernel layout to
    # ConvertLayout.
    desired_layouts = {
        "nn.conv2d": [desired_layout, "default"],
        "nn.conv2d_transpose": [desired_layout, "default"],
        "qnn.conv2d": [desired_layout, "default"],
    }

    # Convert the layout of the graph where possible.
    seq = transform.Sequential(
        [
            relay.transform.RemoveUnusedFunctions(),
            relay.transform.ConvertLayout(desired_layouts),
        ]
    )

    with transform.PassContext(opt_level=3):
        try:
            return seq(mod)
        except Exception as err:
            # Surface any pass failure as a TVMC error, keeping the
            # underlying TVM error explicitly chained for debugging.
            raise TVMCException(
                "Error converting layout to {0}: {1}".format(desired_layout, str(err))
            ) from err
diff --git a/tests/python/driver/tvmc/test_autotuner.py
b/tests/python/driver/tvmc/test_autotuner.py
index 52c3317..a1915a0 100644
--- a/tests/python/driver/tvmc/test_autotuner.py
+++ b/tests/python/driver/tvmc/test_autotuner.py
@@ -153,7 +153,7 @@ def test_tune_tasks__invalid_tuner(onnx_mnist,
tmpdir_factory):
tasks = _get_tasks(onnx_mnist)
log_file = os.path.join(tmpdir_factory.mktemp("data"), "log2.txt")
- with pytest.raises(tvmc.common.TVMCException):
+ with pytest.raises(tvmc.TVMCException):
tvmc.autotuner.tune_tasks(tasks, log_file, _get_measure_options(),
"invalid_tuner", 1, 1)
diff --git a/tests/python/driver/tvmc/test_compiler.py
b/tests/python/driver/tvmc/test_compiler.py
index 73f3a0f..5ebcb8e 100644
--- a/tests/python/driver/tvmc/test_compiler.py
+++ b/tests/python/driver/tvmc/test_compiler.py
@@ -70,7 +70,7 @@ def test_compile_tflite_module(tflite_mobilenet_v1_1_quant):
verify_compile_tflite_module(tflite_mobilenet_v1_1_quant)
# Check with manual shape override
shape_string = "input:[1,224,224,3]"
- shape_dict = tvmc.common.parse_shape_string(shape_string)
+ shape_dict = tvmc.shape_parser.parse_shape_string(shape_string)
verify_compile_tflite_module(tflite_mobilenet_v1_1_quant, shape_dict)
@@ -218,7 +218,7 @@ def test_compile_onnx_module(onnx_resnet50):
verify_compile_onnx_module(onnx_resnet50)
# Test with manual shape dict
shape_string = "data:[1,3,200,200]"
- shape_dict = tvmc.common.parse_shape_string(shape_string)
+ shape_dict = tvmc.shape_parser.parse_shape_string(shape_string)
verify_compile_onnx_module(onnx_resnet50, shape_dict)
@@ -296,7 +296,7 @@ def test_compile_paddle_module(paddle_resnet50):
verify_compile_paddle_module(paddle_resnet50)
# Check with manual shape override
shape_string = "inputs:[1,3,224,224]"
- shape_dict = tvmc.common.parse_shape_string(shape_string)
+ shape_dict = tvmc.shape_parser.parse_shape_string(shape_string)
verify_compile_paddle_module(paddle_resnet50, shape_dict)
diff --git a/tests/python/driver/tvmc/test_composite_target.py
b/tests/python/driver/tvmc/test_composite_target.py
index 80b4d1b..dfaf30c 100644
--- a/tests/python/driver/tvmc/test_composite_target.py
+++ b/tests/python/driver/tvmc/test_composite_target.py
@@ -27,7 +27,7 @@ import tvm
from tvm.driver import tvmc
-from tvm.driver.tvmc.common import TVMCException
+from tvm.driver.tvmc import TVMCException
def test_get_codegen_names():
diff --git a/tests/python/driver/tvmc/test_frontends.py
b/tests/python/driver/tvmc/test_frontends.py
index e887857..b760669 100644
--- a/tests/python/driver/tvmc/test_frontends.py
+++ b/tests/python/driver/tvmc/test_frontends.py
@@ -24,8 +24,7 @@ from unittest import mock
from tvm.ir.module import IRModule
from tvm.driver import tvmc
-from tvm.driver.tvmc.common import TVMCException
-from tvm.driver.tvmc.common import TVMCImportError
+from tvm.driver.tvmc import TVMCException, TVMCImportError
from tvm.driver.tvmc.model import TVMCModel
@@ -268,7 +267,7 @@ def
test_compile_tflite_module_nhwc_to_nchw(tflite_mobilenet_v1_1_quant):
before = tvmc_model.mod
expected_layout = "NCHW"
- after = tvmc.common.convert_graph_layout(before, expected_layout)
+ after = tvmc.transform.convert_graph_layout(before, expected_layout)
layout_transform_calls = []
@@ -293,7 +292,7 @@ def test_compile_onnx_module_nchw_to_nhwc(onnx_resnet50):
before = tvmc_model.mod
expected_layout = "NHWC"
- after = tvmc.common.convert_graph_layout(before, expected_layout)
+ after = tvmc.transform.convert_graph_layout(before, expected_layout)
layout_transform_calls = []
@@ -318,7 +317,7 @@ def
test_compile_paddle_module_nchw_to_nhwc(paddle_resnet50):
before = tvmc_model.mod
expected_layout = "NHWC"
- after = tvmc.common.convert_graph_layout(before, expected_layout)
+ after = tvmc.transform.convert_graph_layout(before, expected_layout)
layout_transform_calls = []
@@ -343,7 +342,7 @@ def
test_compile_tflite_module__same_layout__nhwc_to_nhwc(tflite_mobilenet_v1_1_
before = tvmc_model.mod
expected_layout = "NHWC"
- after = tvmc.common.convert_graph_layout(before, expected_layout)
+ after = tvmc.transform.convert_graph_layout(before, expected_layout)
layout_transform_calls = []
@@ -368,7 +367,7 @@ def
test_compile_onnx_module__same_layout__nchw_to_nchw(onnx_resnet50):
before = tvmc_model.mod
expected_layout = "NCHW"
- after = tvmc.common.convert_graph_layout(before, expected_layout)
+ after = tvmc.transform.convert_graph_layout(before, expected_layout)
layout_transform_calls = []
diff --git a/tests/python/driver/tvmc/test_pass_config.py
b/tests/python/driver/tvmc/test_pass_config.py
index d8ffd7d..bb815e1 100644
--- a/tests/python/driver/tvmc/test_pass_config.py
+++ b/tests/python/driver/tvmc/test_pass_config.py
@@ -18,33 +18,33 @@
import pytest
from tvm.contrib.target.vitis_ai import vitis_ai_available
-from tvm.driver import tvmc
-from tvm.driver.tvmc.common import TVMCException
+from tvm.driver.tvmc import TVMCException
+from tvm.driver.tvmc.pass_config import parse_configs
def test_config_invalid_format():
with pytest.raises(TVMCException):
- _ =
tvmc.common.parse_configs(["relay.backend.use_auto_scheduler.missing.value"])
+ _ = parse_configs(["relay.backend.use_auto_scheduler.missing.value"])
def test_config_missing_from_tvm():
with pytest.raises(TVMCException):
- _ =
tvmc.common.parse_configs(["relay.backend.use_auto_scheduler.missing.value=1234"])
+ _ =
parse_configs(["relay.backend.use_auto_scheduler.missing.value=1234"])
def test_config_unsupported_tvmc_config():
with pytest.raises(TVMCException):
- _ = tvmc.common.parse_configs(["tir.LoopPartition=value"])
+ _ = parse_configs(["tir.LoopPartition=value"])
def test_config_empty():
with pytest.raises(TVMCException):
- _ = tvmc.common.parse_configs([""])
+ _ = parse_configs([""])
def test_config_valid_config_bool():
- configs =
tvmc.common.parse_configs(["relay.backend.use_auto_scheduler=true"])
+ configs = parse_configs(["relay.backend.use_auto_scheduler=true"])
assert len(configs) == 1
assert "relay.backend.use_auto_scheduler" in configs.keys()
@@ -56,7 +56,7 @@ def test_config_valid_config_bool():
reason="--target vitis-ai is not available. TVM built with 'USE_VITIS_AI
OFF'",
)
def test_config_valid_multiple_configs():
- configs = tvmc.common.parse_configs(
+ configs = parse_configs(
[
"relay.backend.use_auto_scheduler=false",
"tir.detect_global_barrier=10",
diff --git a/tests/python/driver/tvmc/test_pass_list.py
b/tests/python/driver/tvmc/test_pass_list.py
index de50b04..f43da63 100644
--- a/tests/python/driver/tvmc/test_pass_list.py
+++ b/tests/python/driver/tvmc/test_pass_list.py
@@ -17,15 +17,15 @@
import argparse
import pytest
-from tvm.driver import tvmc
+from tvm.driver.tvmc.pass_list import parse_pass_list_str
def test_parse_pass_list_str():
- assert [""] == tvmc.common.parse_pass_list_str("")
- assert ["FoldScaleAxis", "FuseOps"] ==
tvmc.common.parse_pass_list_str("FoldScaleAxis,FuseOps")
+ assert [""] == parse_pass_list_str("")
+ assert ["FoldScaleAxis", "FuseOps"] ==
parse_pass_list_str("FoldScaleAxis,FuseOps")
with pytest.raises(argparse.ArgumentTypeError) as ate:
- tvmc.common.parse_pass_list_str("MyYobaPass,MySuperYobaPass,FuseOps")
+ parse_pass_list_str("MyYobaPass,MySuperYobaPass,FuseOps")
assert "MyYobaPass" in str(ate.value)
assert "MySuperYobaPass" in str(ate.value)
diff --git a/tests/python/driver/tvmc/test_registry_options.py
b/tests/python/driver/tvmc/test_registry_options.py
index 458d0a8..dbd7cc0 100644
--- a/tests/python/driver/tvmc/test_registry_options.py
+++ b/tests/python/driver/tvmc/test_registry_options.py
@@ -19,7 +19,7 @@ import argparse
import pytest
-from tvm.driver.tvmc.common import TVMCException
+from tvm.driver.tvmc import TVMCException
from tvm.driver.tvmc.registry import generate_registry_args,
reconstruct_registry_entity
from tvm.relay.backend import Executor
diff --git a/tests/python/driver/tvmc/test_runner.py
b/tests/python/driver/tvmc/test_runner.py
index 2ce363a..30ce2c6 100644
--- a/tests/python/driver/tvmc/test_runner.py
+++ b/tests/python/driver/tvmc/test_runner.py
@@ -48,7 +48,7 @@ def test_generate_tensor_data_random():
def test_generate_tensor_data__type_unknown():
- with pytest.raises(tvmc.common.TVMCException) as e:
+ with pytest.raises(tvmc.TVMCException) as e:
tvmc.runner.generate_tensor_data((2, 3), "float32", "whatever")
diff --git a/tests/python/driver/tvmc/test_shape_parser.py
b/tests/python/driver/tvmc/test_shape_parser.py
index f49d89a..1e3cde1 100644
--- a/tests/python/driver/tvmc/test_shape_parser.py
+++ b/tests/python/driver/tvmc/test_shape_parser.py
@@ -19,19 +19,19 @@ import argparse
import pytest
-from tvm.driver import tvmc
+from tvm.driver.tvmc.shape_parser import parse_shape_string
def test_shape_parser():
# Check that a valid input is parsed correctly
shape_string = "input:[10,10,10]"
- shape_dict = tvmc.common.parse_shape_string(shape_string)
+ shape_dict = parse_shape_string(shape_string)
assert shape_dict == {"input": [10, 10, 10]}
def test_alternate_syntax():
shape_string = "input:0:[10,10,10] input2:[20,20,20,20]"
- shape_dict = tvmc.common.parse_shape_string(shape_string)
+ shape_dict = parse_shape_string(shape_string)
assert shape_dict == {"input:0": [10, 10, 10], "input2": [20, 20, 20, 20]}
@@ -44,14 +44,14 @@ def test_alternate_syntax():
],
)
def test_alternate_syntaxes(shape_string):
- shape_dict = tvmc.common.parse_shape_string(shape_string)
+ shape_dict = parse_shape_string(shape_string)
assert shape_dict == {"input": [10, 10, 10], "input2": [20, 20, 20, 20]}
def test_negative_dimensions():
# Check that negative dimensions parse to Any correctly.
shape_string = "input:[-1,3,224,224]"
- shape_dict = tvmc.common.parse_shape_string(shape_string)
+ shape_dict = parse_shape_string(shape_string)
# Convert to strings to allow comparison with Any.
assert str(shape_dict) == "{'input': [?, 3, 224, 224]}"
@@ -59,7 +59,7 @@ def test_negative_dimensions():
def test_multiple_valid_gpu_inputs():
# Check that multiple valid gpu inputs are parsed correctly.
shape_string = "gpu_0/data_0:[1, -1,224,224] gpu_1/data_1:[7, 7]"
- shape_dict = tvmc.common.parse_shape_string(shape_string)
+ shape_dict = parse_shape_string(shape_string)
expected = "{'gpu_0/data_0': [1, ?, 224, 224], 'gpu_1/data_1': [7, 7]}"
assert str(shape_dict) == expected
@@ -67,19 +67,19 @@ def test_multiple_valid_gpu_inputs():
def test_invalid_pattern():
shape_string = "input:[a,10]"
with pytest.raises(argparse.ArgumentTypeError):
- tvmc.common.parse_shape_string(shape_string)
+ parse_shape_string(shape_string)
def test_invalid_separators():
shape_string = "input:5,10 input2:10,10"
with pytest.raises(argparse.ArgumentTypeError):
- tvmc.common.parse_shape_string(shape_string)
+ parse_shape_string(shape_string)
def test_invalid_colon():
shape_string = "gpu_0/data_0:5,10 :test:10,10"
with pytest.raises(argparse.ArgumentTypeError):
- tvmc.common.parse_shape_string(shape_string)
+ parse_shape_string(shape_string)
@pytest.mark.parametrize(
@@ -93,11 +93,11 @@ def test_invalid_colon():
)
def test_invalid_slashes(shape_string):
with pytest.raises(argparse.ArgumentTypeError):
- tvmc.common.parse_shape_string(shape_string)
+ parse_shape_string(shape_string)
def test_dot():
# Check dot in input name
shape_string = "input.1:[10,10,10]"
- shape_dict = tvmc.common.parse_shape_string(shape_string)
+ shape_dict = parse_shape_string(shape_string)
assert shape_dict == {"input.1": [10, 10, 10]}
diff --git a/tests/python/driver/tvmc/test_target.py
b/tests/python/driver/tvmc/test_target.py
index 06db5c4..532ecbe 100644
--- a/tests/python/driver/tvmc/test_target.py
+++ b/tests/python/driver/tvmc/test_target.py
@@ -17,33 +17,32 @@
import pytest
-from tvm.driver import tvmc
-
-from tvm.driver.tvmc.common import TVMCException
+from tvm.driver.tvmc import TVMCException
+from tvm.driver.tvmc.target import target_from_cli, tokenize_target,
parse_target
def test_target_from_cli__error_duplicate():
with pytest.raises(TVMCException):
- _ = tvmc.common.target_from_cli("llvm, llvm")
+ _ = target_from_cli("llvm, llvm")
def test_target_invalid_more_than_two_tvm_targets():
with pytest.raises(TVMCException):
- _ = tvmc.common.target_from_cli("cuda, opencl, llvm")
+ _ = target_from_cli("cuda, opencl, llvm")
def test_target_from_cli__error_target_not_found():
with pytest.raises(TVMCException):
- _ = tvmc.common.target_from_cli("invalidtarget")
+ _ = target_from_cli("invalidtarget")
def test_target_from_cli__error_no_tvm_target():
with pytest.raises(TVMCException):
- _ = tvmc.common.target_from_cli("ethos-n77")
+ _ = target_from_cli("ethos-n77")
def test_target_two_tvm_targets():
- tvm_target, extra_targets = tvmc.common.target_from_cli(
+ tvm_target, extra_targets = target_from_cli(
"opencl -device=mali, llvm -mtriple=aarch64-linux-gnu"
)
@@ -55,7 +54,7 @@ def test_target_two_tvm_targets():
def test_tokenize_target_with_opts():
- tokens = tvmc.common.tokenize_target("foo -opt1=value1 --flag, bar
-opt2=value2")
+ tokens = tokenize_target("foo -opt1=value1 --flag, bar -opt2=value2")
expected_tokens = ["foo", "-opt1=value1", "--flag", ",", "bar",
"-opt2=value2"]
assert len(tokens) == len(expected_tokens)
@@ -63,7 +62,7 @@ def test_tokenize_target_with_opts():
def test_tokenize_target_with_plus_sign():
- tokens = tvmc.common.tokenize_target("foo -opt1=+value1 --flag, bar
-opt2=test,+v")
+ tokens = tokenize_target("foo -opt1=+value1 --flag, bar -opt2=test,+v")
expected_tokens = ["foo", "-opt1=+value1", "--flag", ",", "bar",
"-opt2=test,+v"]
assert len(tokens) == len(expected_tokens)
@@ -71,7 +70,7 @@ def test_tokenize_target_with_plus_sign():
def test_tokenize_target_with_commas():
- tokens = tvmc.common.tokenize_target("foo -opt1=v,a,l,u,e,1 --flag")
+ tokens = tokenize_target("foo -opt1=v,a,l,u,e,1 --flag")
expected_tokens = ["foo", "-opt1=v,a,l,u,e,1", "--flag"]
assert len(tokens) == len(expected_tokens)
@@ -79,7 +78,7 @@ def test_tokenize_target_with_commas():
def test_tokenize_target_with_commas_and_single_quotes():
- tokens = tvmc.common.tokenize_target("foo -opt1='v, a, l, u, e', bar")
+ tokens = tokenize_target("foo -opt1='v, a, l, u, e', bar")
expected_tokens = ["foo", "-opt1='v, a, l, u, e'", ",", "bar"]
assert len(tokens) == len(expected_tokens)
@@ -87,7 +86,7 @@ def test_tokenize_target_with_commas_and_single_quotes():
def test_tokenize_target_with_commas_and_double_quotes():
- tokens = tvmc.common.tokenize_target('foo -opt1="v, a, l, u, e", bar')
+ tokens = tokenize_target('foo -opt1="v, a, l, u, e", bar')
expected_tokens = ["foo", '-opt1="v, a, l, u, e"', ",", "bar"]
assert len(tokens) == len(expected_tokens)
@@ -95,7 +94,7 @@ def test_tokenize_target_with_commas_and_double_quotes():
def test_tokenize_target_with_dashes():
- tokens = tvmc.common.tokenize_target("foo-bar1 -opt-1=t-e-s-t, baz")
+ tokens = tokenize_target("foo-bar1 -opt-1=t-e-s-t, baz")
expected_tokens = ["foo-bar1", "-opt-1=t-e-s-t", ",", "baz"]
assert len(tokens) == len(expected_tokens)
@@ -103,7 +102,7 @@ def test_tokenize_target_with_dashes():
def test_parse_single_target_with_opts():
- targets = tvmc.common.parse_target("llvm -device=arm_cpu -mattr=+fp")
+ targets = parse_target("llvm -device=arm_cpu -mattr=+fp")
assert len(targets) == 1
assert "device" in targets[0]["opts"]
@@ -111,7 +110,7 @@ def test_parse_single_target_with_opts():
def test_parse_multiple_target():
- targets = tvmc.common.parse_target("compute-library, llvm -device=arm_cpu")
+ targets = parse_target("compute-library, llvm -device=arm_cpu")
assert len(targets) == 2
assert "compute-library" == targets[0]["name"]
@@ -120,7 +119,7 @@ def test_parse_multiple_target():
def test_parse_hybrid_target():
"""Hybrid Target and external codegen"""
- targets = tvmc.common.parse_target(
+ targets = parse_target(
"cmsis-nn -accelerator_config=ethos-u55-256, llvm -device=arm_cpu
--system-lib"
)
@@ -132,9 +131,9 @@ def test_parse_hybrid_target():
def test_parse_quotes_and_separators_on_options():
- targets_no_quote = tvmc.common.parse_target("foo
-option1=+v1.0x,+value,+bar")
- targets_single_quote = tvmc.common.parse_target("foo
-option1='+v1.0x,+value'")
- targets_double_quote = tvmc.common.parse_target('foo
-option1="+v1.0x,+value"')
+ targets_no_quote = parse_target("foo -option1=+v1.0x,+value,+bar")
+ targets_single_quote = parse_target("foo -option1='+v1.0x,+value'")
+ targets_double_quote = parse_target('foo -option1="+v1.0x,+value"')
assert len(targets_no_quote) == 1
assert "+v1.0x,+value,+bar" == targets_no_quote[0]["opts"]["option1"]
@@ -147,7 +146,7 @@ def test_parse_quotes_and_separators_on_options():
def test_parse_multiple_target_with_opts_ethos_n77():
- targets = tvmc.common.parse_target("ethos-n77 -myopt=value, llvm
-device=arm_cpu --system-lib")
+ targets = parse_target("ethos-n77 -myopt=value, llvm -device=arm_cpu
--system-lib")
assert len(targets) == 2
assert "ethos-n77" == targets[0]["name"]
@@ -157,7 +156,7 @@ def test_parse_multiple_target_with_opts_ethos_n77():
def test_parse_multiple_target_with_opts_ethos_n78():
- targets = tvmc.common.parse_target("ethos-n78 -myopt=value, llvm
-device=arm_cpu --system-lib")
+ targets = parse_target("ethos-n78 -myopt=value, llvm -device=arm_cpu
--system-lib")
assert len(targets) == 2
assert "ethos-n78" == targets[0]["name"]
diff --git a/tests/python/driver/tvmc/test_target_options.py
b/tests/python/driver/tvmc/test_target_options.py
index b592d50..1bcad48 100644
--- a/tests/python/driver/tvmc/test_target_options.py
+++ b/tests/python/driver/tvmc/test_target_options.py
@@ -19,9 +19,8 @@ import argparse
import pytest
-from tvm.driver import tvmc
-from tvm.driver.tvmc.common import TVMCException
-from tvm.driver.tvmc.target import generate_target_args,
reconstruct_target_args
+from tvm.driver.tvmc import TVMCException
+from tvm.driver.tvmc.target import generate_target_args,
reconstruct_target_args, target_from_cli
def test_target_to_argparse():
@@ -53,13 +52,13 @@ def test_skip_target_from_codegen():
def test_target_recombobulation_single():
- tvm_target, _ = tvmc.common.target_from_cli("llvm", {"llvm": {"mcpu":
"cortex-m3"}})
+ tvm_target, _ = target_from_cli("llvm", {"llvm": {"mcpu": "cortex-m3"}})
assert str(tvm_target) == "llvm -keys=cpu -link-params=0 -mcpu=cortex-m3"
def test_target_recombobulation_many():
- tvm_target, _ = tvmc.common.target_from_cli(
+ tvm_target, _ = target_from_cli(
"opencl -device=mali, llvm -mtriple=aarch64-linux-gnu",
{"llvm": {"mcpu": "cortex-m3"}, "opencl": {"max_num_threads": 404}},
)
@@ -75,7 +74,7 @@ def test_error_if_target_missing():
TVMCException,
match="Passed --target-opencl-max_num_threads but did not specify
opencl target",
):
- tvmc.common.target_from_cli(
+ target_from_cli(
"llvm",
{"opencl": {"max_num_threads": 404}},
)
diff --git a/tests/python/driver/tvmc/test_tracker.py
b/tests/python/driver/tvmc/test_tracker.py
index 2ca0fae..8734ad5 100644
--- a/tests/python/driver/tvmc/test_tracker.py
+++ b/tests/python/driver/tvmc/test_tracker.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-from tvm.driver import tvmc
+from tvm.driver.tvmc.tracker import tracker_host_port_from_cli
def test_tracker_host_port_from_cli__hostname_port():
@@ -23,7 +23,7 @@ def test_tracker_host_port_from_cli__hostname_port():
expected_host = "1.2.3.4"
expected_port = 9090
- actual_host, actual_port =
tvmc.common.tracker_host_port_from_cli(input_str)
+ actual_host, actual_port = tracker_host_port_from_cli(input_str)
assert expected_host == actual_host
assert expected_port == actual_port
@@ -32,7 +32,7 @@ def test_tracker_host_port_from_cli__hostname_port():
def test_tracker_host_port_from_cli__hostname_port__empty():
input_str = ""
- actual_host, actual_port =
tvmc.common.tracker_host_port_from_cli(input_str)
+ actual_host, actual_port = tracker_host_port_from_cli(input_str)
assert actual_host is None
assert actual_port is None
@@ -43,7 +43,7 @@ def
test_tracker_host_port_from_cli__only_hostname__default_port_is_9090():
expected_host = "1.2.3.4"
expected_port = 9090
- actual_host, actual_port =
tvmc.common.tracker_host_port_from_cli(input_str)
+ actual_host, actual_port = tracker_host_port_from_cli(input_str)
assert expected_host == actual_host
assert expected_port == actual_port