jwfromm commented on a change in pull request #7823:
URL: https://github.com/apache/tvm/pull/7823#discussion_r615327011



##########
File path: python/tvm/driver/tvmc/autotuner.py
##########
@@ -228,24 +240,137 @@ def drive_tune(args):
     args: argparse.Namespace
         Arguments from command line parser.
     """
-    # extra arguments validation before importing the model, so that obvious 
errors
-    # are pointed in advance.
-    if args.rpc_tracker:
-        parsed_url = urlparse("//%s" % args.rpc_tracker)
+    tvmc_model = frontends.load_model(args.FILE, args.model_format, 
shape_dict=args.input_shapes)
+    tvmc_model.tuning_records = args.tuning_records
+    # Specify hardware parameters, although they'll only be used if 
autoscheduling.
+    hardware_params = auto_scheduler.HardwareParams(
+        args.num_cores,
+        args.vector_unit_bytes,
+        args.cache_line_bytes,
+        args.max_shared_memory_per_block,
+        args.max_local_memory_per_block,
+        args.max_threads_per_block,
+        args.max_vthread_extent,
+        args.warp_size,
+        args.target,
+        args.target_host,
+    )
+
+    tune_model(
+        tvmc_model,
+        args.target,
+        args.output,
+        args.enable_autoscheduler,
+        args.rpc_key,
+        args.rpc_tracker,
+        args.trials,
+        args.target_host,
+        args.tuner,
+        args.min_repeat_ms,
+        args.early_stopping,
+        args.desired_layout,
+        args.timeout,
+        args.number,
+        args.repeat,
+        args.parallel,
+        hardware_params,
+        args.include_simple_tasks,
+        args.log_estimated_latency,
+    )
+
+
+def tune_model(
+    tvmc_model: TVMCModel,
+    target: str,
+    tuning_records: Optional[str] = None,
+    enable_autoscheduler: bool = False,
+    rpc_key: Optional[str] = None,
+    rpc_tracker: Optional[str] = None,
+    trials: Optional[int] = None,
+    target_host: str = "llvm",
+    tuner: str = "xgb",
+    min_repeat_ms: Optional[int] = None,
+    early_stopping: Optional[int] = None,
+    desired_layout: Optional[str] = None,
+    timeout: int = 10,
+    number: int = 10,
+    repeat: int = 1,
+    parallel: int = 4,
+    hardware_params: Optional[HardwareParams] = None,
+    include_simple_tasks: bool = False,
+    log_estimated_latency: bool = False,
+):
+    """Use tuning to automatically optimize the functions in a model.
+
+    Parameters
+    ----------
+    tvmc_model : TVMCModel
+        The model to be optimized.
+    target : str
+        Compilation target as plain string, inline JSON or path to a JSON file.
+    tuning_records: str, optional
+        The path to a file that tuning results will be saved to. If not 
specified,
+        a temporary file will be used.
+    enable_autoscheduler : bool, optional
+        When true, use autoscheduling rather than autotvm. This should produce
+        faster kernels for compatible model-target pairs.
+    rpc_key : str, optional
+        The RPC tracker key of the target device. Required when rpc_tracker is 
provided.
+    rpc_tracker : str, optional
+        The hostname and port (optional, defaults to 9090) of the RPC tracker,
+        e.g. 192.168.0.100:9999.
+    trials : int, optional
+        The number of schedules to try out. For autotvm, each task will have 
this many
+        options explored. For autoscheduling, the total number of schedules 
checked in
+        the entire model will be this many.
+    target_host : str, optional
+        The host compilation target, defaults to 'llvm'.
+    tuner : str, optional
+        The type of tuner to use when tuning with autotvm. Can be one of
+        "ga", "gridsearch", "random", "xgb", "xgb_knob", and "xgb-rank".
+    min_repeat_ms : int, optional
+        Minimum time to run each trial. Defaults to 0 on x86 and 1000 on other 
targets.
+    early_stopping : int, optional
+        When specified, stop tuning after this number of trials if results 
aren't improving.
+    desired_layout : str, optional
+        Can be one of "NCHW" or "NHWC". When specified, the graph will be 
converted to this layout.
+    timeout : int, optional,
+        If a kernel trial lasts longer than this duration in seconds, it will 
be
+        considered a failure.
+    number : int, optional
+        The number of runs a single repeat is made of.
+    repeat : int, optional
+        How many times each measurement should be repeated.
+    parallel : int, optional
+        The maximum number of parallel devices to use when tuning.
+    hardware_params : auto_scheduler.HardwareParams, optional
+        When using the autoscheduler, this object defines the configuration of 
the target hardware.
+    include_simple_tasks : bool, optional
+        Whether to extract simple operations or only computationally intensive 
ones when using
+        the autoscheduler.
+    log_estimated_latency : bool, optional
+        If using the autoscheduler, write the estimated latency at each step 
of tuning to file.
+
+    """
+    if rpc_tracker:
+        parsed_url = urlparse("//%s" % rpc_tracker)
         rpc_hostname = parsed_url.hostname
         rpc_port = parsed_url.port or 9090
         logger.info("RPC tracker hostname: %s", rpc_hostname)
         logger.info("RPC tracker port: %s", rpc_port)
 
-        if not args.rpc_key:
+        if not rpc_key:
             raise common.TVMCException(
                 "need to provide an RPC tracker key (--rpc-key) for remote 
tuning"
             )
 
-    target, extra_targets = common.target_from_cli(args.target)
-    target_host = args.target_host
-    target, target_host = Target.check_and_update_host_consist(target, 
target_host)
-    mod, params = frontends.load_model(args.FILE, args.model_format, 
shape_dict=args.input_shapes)
+    target, extra_targets = common.target_from_cli(target)
+    if target_host is not None:

Review comment:
       The TVMC command line still uses `target_host`, so I would like to support it here as well.
I imagine we'll remove it eventually, though.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to