comaniac commented on a change in pull request #8253:
URL: https://github.com/apache/tvm/pull/8253#discussion_r651120543
##########
File path: python/tvm/driver/tvmc/common.py
##########
@@ -415,3 +415,86 @@ def parse_shape_string(inputs_string):
shape_dict[name] = shape
return shape_dict
+
+
+def set_config_value(name, value, config_type):
+ """Set a PassContext configuration value according to its value"""
+
+ if config_type == "IntImm":
+ # "Bool" configurations in the PassContext are recognized as
+ # IntImm, so deal with this case here
+ mapping_values = {
+ "false": False,
+ "true": True,
+ }
+
+ if value.isdigit():
+ parsed_value = int(value)
+ else:
+ # if not an int, accept only values on the mapping table, case
insensitive
+ parsed_value = mapping_values.get(value.lower(), None)
+
+ if parsed_value is None:
+ raise TVMCException(f"Invalid value '{value}' for configuration
'{name}'. ")
+
+ if config_type == "runtime.String":
+ parsed_value = value
+
+ return parsed_value
+
+
+def parse_configs(input_configs):
+ """Parse configuration values set via command line.
+
+ Parameters
+ ----------
+ input_configs: list of str
+ list of configurations provided via command line.
+
+ Returns
+ -------
+ pass_context_configs: dict
+ a dict containing key-value configs to be used in the PassContext.
+ """
+ all_configs = tvm.ir.transform.PassContext.list_configs()
+ supported_config_types = ("IntImm", "runtime.String")
+ supported_configs = [
+ name for name in all_configs.keys() if all_configs[name]["type"] in
supported_config_types
+ ]
+ pass_context_configs = {}
+
+ if not input_configs:
+ return {}
Review comment:
Move this to the beginning of this function so that you don't need to
process all available configs if users don't specify any.
##########
File path: python/tvm/driver/tvmc/common.py
##########
@@ -415,3 +415,86 @@ def parse_shape_string(inputs_string):
shape_dict[name] = shape
return shape_dict
+
+
+def set_config_value(name, value, config_type):
+ """Set a PassContext configuration value according to its value"""
+
+ if config_type == "IntImm":
+ # "Bool" configurations in the PassContext are recognized as
+ # IntImm, so deal with this case here
+ mapping_values = {
+ "false": False,
+ "true": True,
+ }
+
+ if value.isdigit():
+ parsed_value = int(value)
+ else:
+ # if not an int, accept only values on the mapping table, case
insensitive
+ parsed_value = mapping_values.get(value.lower(), None)
+
+ if parsed_value is None:
+ raise TVMCException(f"Invalid value '{value}' for configuration
'{name}'. ")
+
+ if config_type == "runtime.String":
+ parsed_value = value
+
+ return parsed_value
+
+
+def parse_configs(input_configs):
+ """Parse configuration values set via command line.
+
+ Parameters
+ ----------
+ input_configs: list of str
+ list of configurations provided via command line.
+
+ Returns
+ -------
+ pass_context_configs: dict
+ a dict containing key-value configs to be used in the PassContext.
+ """
+ all_configs = tvm.ir.transform.PassContext.list_configs()
+ supported_config_types = ("IntImm", "runtime.String")
+ supported_configs = [
+ name for name in all_configs.keys() if all_configs[name]["type"] in
supported_config_types
+ ]
+ pass_context_configs = {}
+
+ if not input_configs:
+ return {}
+
+ for config in input_configs:
+ if len(config) == 0:
Review comment:
```suggestion
if not config:
```
##########
File path: python/tvm/driver/tvmc/compiler.py
##########
@@ -42,6 +42,13 @@ def add_compile_parser(subparsers):
parser = subparsers.add_parser("compile", help="compile a model.")
parser.set_defaults(func=drive_compile)
+ parser.add_argument(
+ "--config",
Review comment:
Would `build-config` or `pass-config` be more intuitive?
##########
File path: python/tvm/driver/tvmc/common.py
##########
@@ -415,3 +415,86 @@ def parse_shape_string(inputs_string):
shape_dict[name] = shape
return shape_dict
+
+
+def set_config_value(name, value, config_type):
+ """Set a PassContext configuration value according to its value"""
+
+ if config_type == "IntImm":
+ # "Bool" configurations in the PassContext are recognized as
+ # IntImm, so deal with this case here
+ mapping_values = {
+ "false": False,
+ "true": True,
+ }
+
+ if value.isdigit():
+ parsed_value = int(value)
+ else:
+ # if not an int, accept only values on the mapping table, case
insensitive
+ parsed_value = mapping_values.get(value.lower(), None)
+
+ if parsed_value is None:
+ raise TVMCException(f"Invalid value '{value}' for configuration
'{name}'. ")
+
+ if config_type == "runtime.String":
+ parsed_value = value
Review comment:
```suggestion
parsed_value = value
if config_type == "IntImm":
if value.isdigit():
parsed_value = int(value)
else:
# must be boolean values if not an int
try:
parsed_value = bool(distutils.util.strtobool(value))
except ValueError as err:
raise TVMCException(f"Invalid value '{value}' for configuration '{name}'. ")
```
If the dependency on `distutils` is a concern (it is deprecated since Python 3.10 and removed in 3.12), the following also works:
```
try:
parsed_value = json.loads(value)
except json.decoder.JSONDecodeError as err:
...
```
##########
File path: python/tvm/driver/tvmc/common.py
##########
@@ -415,3 +415,86 @@ def parse_shape_string(inputs_string):
shape_dict[name] = shape
return shape_dict
+
+
+def set_config_value(name, value, config_type):
+ """Set a PassContext configuration value according to its value"""
+
+ if config_type == "IntImm":
+ # "Bool" configurations in the PassContext are recognized as
+ # IntImm, so deal with this case here
+ mapping_values = {
+ "false": False,
+ "true": True,
+ }
+
+ if value.isdigit():
+ parsed_value = int(value)
+ else:
+ # if not an int, accept only values on the mapping table, case
insensitive
+ parsed_value = mapping_values.get(value.lower(), None)
+
+ if parsed_value is None:
+ raise TVMCException(f"Invalid value '{value}' for configuration
'{name}'. ")
+
+ if config_type == "runtime.String":
+ parsed_value = value
+
+ return parsed_value
+
+
+def parse_configs(input_configs):
+ """Parse configuration values set via command line.
+
+ Parameters
+ ----------
+ input_configs: list of str
+ list of configurations provided via command line.
+
+ Returns
+ -------
+ pass_context_configs: dict
+ a dict containing key-value configs to be used in the PassContext.
+ """
+ all_configs = tvm.ir.transform.PassContext.list_configs()
+ supported_config_types = ("IntImm", "runtime.String")
+ supported_configs = [
+ name for name in all_configs.keys() if all_configs[name]["type"] in
supported_config_types
+ ]
+ pass_context_configs = {}
+
+ if not input_configs:
+ return {}
+
+ for config in input_configs:
+ if len(config) == 0:
+ raise TVMCException(
+ f"Invalid format for configuration '{config}', use
<config>=<value>"
+ )
+
+ # Each config is expected to be provided as "name=value"
+ try:
+ name, value = config.split("=")
+ name = name.strip()
+ value = value.strip()
+ except ValueError:
+ raise TVMCException(
+ f"Invalid format for configuration '{config}', use
<config>=<value>"
+ )
+
+ if name not in all_configs:
+ raise TVMCException(
+ f"Configuration '{name}' is not defined in TVM. "
+ f"These are the existing configurations: {',
'.join(all_configs)}"
+ )
+
+ if name not in supported_configs:
+ raise TVMCException(
+ f"Configuration '{name}' is not supported in TVMC. "
Review comment:
It would be better to explain the reason (i.e., that the config's value type is not supported).
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]