jwfromm commented on a change in pull request #7823:
URL: https://github.com/apache/tvm/pull/7823#discussion_r617147477
##########
File path: python/tvm/driver/tvmc/autotuner.py
##########
@@ -228,24 +240,137 @@ def drive_tune(args):
args: argparse.Namespace
Arguments from command line parser.
"""
- # extra arguments validation before importing the model, so that obvious
errors
- # are pointed in advance.
- if args.rpc_tracker:
- parsed_url = urlparse("//%s" % args.rpc_tracker)
+ tvmc_model = frontends.load_model(args.FILE, args.model_format,
shape_dict=args.input_shapes)
+ tvmc_model.tuning_records = args.tuning_records
+ # Specify hardware parameters, although they'll only be used if
autoscheduling.
+ hardware_params = auto_scheduler.HardwareParams(
+ args.num_cores,
+ args.vector_unit_bytes,
+ args.cache_line_bytes,
+ args.max_shared_memory_per_block,
+ args.max_local_memory_per_block,
+ args.max_threads_per_block,
+ args.max_vthread_extent,
+ args.warp_size,
+ args.target,
+ args.target_host,
+ )
+
+ tune_model(
+ tvmc_model,
+ args.target,
+ args.output,
+ args.enable_autoscheduler,
+ args.rpc_key,
+ args.rpc_tracker,
+ args.trials,
+ args.target_host,
+ args.tuner,
+ args.min_repeat_ms,
+ args.early_stopping,
+ args.desired_layout,
+ args.timeout,
+ args.number,
+ args.repeat,
+ args.parallel,
+ hardware_params,
+ args.include_simple_tasks,
+ args.log_estimated_latency,
+ )
+
+
+def tune_model(
+ tvmc_model: TVMCModel,
+ target: str,
+ tuning_records: Optional[str] = None,
+ enable_autoscheduler: bool = False,
+ rpc_key: Optional[str] = None,
+ rpc_tracker: Optional[str] = None,
+ trials: Optional[int] = None,
+ target_host: str = "llvm",
+ tuner: str = "xgb",
+ min_repeat_ms: Optional[int] = None,
+ early_stopping: Optional[int] = None,
+ desired_layout: Optional[str] = None,
+ timeout: int = 10,
+ number: int = 10,
+ repeat: int = 1,
+ parallel: int = 4,
+ hardware_params: Optional[HardwareParams] = None,
+ include_simple_tasks: bool = False,
+ log_estimated_latency: bool = False,
+):
+ """Use tuning to automatically optimize the functions in a model.
+
+ Parameters
+ ----------
+ tvmc_model : TVMCModel
+ The model to be optimized.
+ target : str
+ Compilation target as plain string, inline JSON or path to a JSON file.
+ tuning_records: str, optional
+ The path to a file that tuning results will be saved to. If not
specified,
+ a temporary file will be used.
+ enable_autoscheduler : bool, optional
+ When true, use autoscheduling rather than autotvm. This should produce
+ faster kernels for compatible model-target pairs.
+ rpc_key : str, optional
+ The RPC tracker key of the target device. Required when rpc_tracker is
provided.
+ rpc_tracker : str, optional
+ The hostname and port (optional, defaults to 9090) of the RPC tracker,
+ e.g. 192.168.0.100:9999.
+ trials : int, optional
+ The number of schedules to try out. For autotvm, each task will have
this many
+ options explored. For autoscheduling, the total number of schedules
checked in
+ the entire model will be this many.
+ target_host : str, optional
+ The host compilation target, defaults to 'llvm'.
+ tuner : str, optional
+ The type of tuner to use when tuning with autotvm. Can be one of
+ "ga", "gridsearch", "random", "xgb", "xgb_knob", and "xgb-rank".
+ min_repeat_ms : int, optional
+ Minimum time to run each trial. Defaults to 0 on x86 and 1000 on other
targets.
+ early_stopping : int, optional
+ When specified, stop tuning after this number of trials if results
aren't improving.
+ desired_layout : str, optional
+ Can be one of "NCHW" or "NHWC". When specified, the graph will be
converted to this layout.
+ timeout : int, optional,
+ If a kernel trial lasts longer than this duration in seconds, it will
be
+ considered a failure.
+ number : int, optional
Review comment:
I agree with you, and we will definitely do this, but I don't want to add it
to this PR; it's already too big.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org