jwfromm commented on a change in pull request #7823:
URL: https://github.com/apache/tvm/pull/7823#discussion_r615327011
##########
File path: python/tvm/driver/tvmc/autotuner.py
##########
@@ -228,24 +240,137 @@ def drive_tune(args):
args: argparse.Namespace
Arguments from command line parser.
"""
- # extra arguments validation before importing the model, so that obvious errors
- # are pointed in advance.
- if args.rpc_tracker:
- parsed_url = urlparse("//%s" % args.rpc_tracker)
+ tvmc_model = frontends.load_model(args.FILE, args.model_format, shape_dict=args.input_shapes)
+ tvmc_model.tuning_records = args.tuning_records
+ # Specify hardware parameters, although they'll only be used if autoscheduling.
+ hardware_params = auto_scheduler.HardwareParams(
+ args.num_cores,
+ args.vector_unit_bytes,
+ args.cache_line_bytes,
+ args.max_shared_memory_per_block,
+ args.max_local_memory_per_block,
+ args.max_threads_per_block,
+ args.max_vthread_extent,
+ args.warp_size,
+ args.target,
+ args.target_host,
+ )
+
+ tune_model(
+ tvmc_model,
+ args.target,
+ args.output,
+ args.enable_autoscheduler,
+ args.rpc_key,
+ args.rpc_tracker,
+ args.trials,
+ args.target_host,
+ args.tuner,
+ args.min_repeat_ms,
+ args.early_stopping,
+ args.desired_layout,
+ args.timeout,
+ args.number,
+ args.repeat,
+ args.parallel,
+ hardware_params,
+ args.include_simple_tasks,
+ args.log_estimated_latency,
+ )
+
+
+def tune_model(
+ tvmc_model: TVMCModel,
+ target: str,
+ tuning_records: Optional[str] = None,
+ enable_autoscheduler: bool = False,
+ rpc_key: Optional[str] = None,
+ rpc_tracker: Optional[str] = None,
+ trials: Optional[int] = None,
+ target_host: str = "llvm",
+ tuner: str = "xgb",
+ min_repeat_ms: Optional[int] = None,
+ early_stopping: Optional[int] = None,
+ desired_layout: Optional[str] = None,
+ timeout: int = 10,
+ number: int = 10,
+ repeat: int = 1,
+ parallel: int = 4,
+ hardware_params: Optional[HardwareParams] = None,
+ include_simple_tasks: bool = False,
+ log_estimated_latency: bool = False,
+):
+ """Use tuning to automatically optimize the functions in a model.
+
+ Parameters
+ ----------
+ tvmc_model : TVMCModel
+ The model to be optimized.
+ target : str
+ Compilation target as plain string, inline JSON or path to a JSON file.
+ tuning_records: str, optional
+ The path to a file that tuning results will be saved to. If not specified,
+ a temporary file will be used.
+ enable_autoscheduler : bool, optional
+ When true, use autoscheduling rather than autotvm. This should produce
+ faster kernels for compatible model-target pairs.
+ rpc_key : str, optional
+ The RPC tracker key of the target device. Required when rpc_tracker is provided.
+ rpc_tracker : str, optional
+ The hostname and port (optional, defaults to 9090) of the RPC tracker,
+ e.g. 192.168.0.100:9999.
+ trials : int, optional
+ The number of schedules to try out. For autotvm, each task will have this many
+ options explored. For autoscheduling, the total number of schedules checked in
+ the entire model will be this many.
+ target_host : str, optional
+ The host compilation target, defaults to 'llvm'.
+ tuner : str, optional
+ The type of tuner to use when tuning with autotvm. Can be one of
+ "ga", "gridsearch", "random", "xgb", "xgb_knob", and "xgb-rank".
+ min_repeat_ms : int, optional
+ Minimum time to run each trial. Defaults to 0 on x86 and 1000 on other targets.
+ early_stopping : int, optional
+ When specified, stop tuning after this number of trials if results aren't improving.
+ desired_layout : str, optional
+ Can be one of "NCHW" or "NHWC". When specified, the graph will be converted to this layout.
+ timeout : int, optional,
+ If a kernel trial lasts longer than this duration in seconds, it will be
+ considered a failure.
+ number : int, optional
+ The number of runs a single repeat is made of.
+ repeat : int, optional
+ How many times each measurement should be repeated.
+ parallel : int, optional
+ The maximum number of parallel devices to use when tuning.
+ hardware_params : auto_scheduler.HardwareParams, optional
+ When using the autoscheduler, this object defines the configuration of the target hardware.
+ include_simple_tasks : bool, optional
+ Whether to extract simple operations or only computationally intensive ones when using
+ the autoscheduler.
+ log_estimated_latency : bool, optional
+ If using the autoscheduler, write the estimated latency at each step of tuning to file.
+
+ """
+ if rpc_tracker:
+ parsed_url = urlparse("//%s" % rpc_tracker)
rpc_hostname = parsed_url.hostname
rpc_port = parsed_url.port or 9090
logger.info("RPC tracker hostname: %s", rpc_hostname)
logger.info("RPC tracker port: %s", rpc_port)
- if not args.rpc_key:
+ if not rpc_key:
raise common.TVMCException(
"need to provide an RPC tracker key (--rpc-key) for remote tuning"
)
- target, extra_targets = common.target_from_cli(args.target)
- target_host = args.target_host
- target, target_host = Target.check_and_update_host_consist(target, target_host)
- mod, params = frontends.load_model(args.FILE, args.model_format, shape_dict=args.input_shapes)
+ target, extra_targets = common.target_from_cli(target)
+ if target_host is not None:
Review comment:
The TVMC command line still uses it, so I would like to use it here as well.
I imagine we'll remove it eventually, though.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]