areusch commented on a change in pull request #7823:
URL: https://github.com/apache/tvm/pull/7823#discussion_r616956461



##########
File path: python/tvm/driver/tvmc/autotuner.py
##########
@@ -228,24 +240,137 @@ def drive_tune(args):
     args: argparse.Namespace
         Arguments from command line parser.
     """
-    # extra arguments validation before importing the model, so that obvious errors
-    # are pointed in advance.
-    if args.rpc_tracker:
-        parsed_url = urlparse("//%s" % args.rpc_tracker)
+    tvmc_model = frontends.load_model(args.FILE, args.model_format, shape_dict=args.input_shapes)
+    tvmc_model.tuning_records = args.tuning_records
+    # Specify hardware parameters, although they'll only be used if autoscheduling.
+    hardware_params = auto_scheduler.HardwareParams(
+        args.num_cores,
+        args.vector_unit_bytes,
+        args.cache_line_bytes,
+        args.max_shared_memory_per_block,
+        args.max_local_memory_per_block,
+        args.max_threads_per_block,
+        args.max_vthread_extent,
+        args.warp_size,
+        args.target,
+        args.target_host,
+    )
+
+    tune_model(
+        tvmc_model,
+        args.target,
+        args.output,
+        args.enable_autoscheduler,
+        args.rpc_key,
+        args.rpc_tracker,
+        args.trials,
+        args.target_host,
+        args.tuner,
+        args.min_repeat_ms,
+        args.early_stopping,
+        args.desired_layout,
+        args.timeout,
+        args.number,
+        args.repeat,
+        args.parallel,
+        hardware_params,
+        args.include_simple_tasks,
+        args.log_estimated_latency,
+    )
+
+
+def tune_model(
+    tvmc_model: TVMCModel,
+    target: str,
+    tuning_records: Optional[str] = None,
+    enable_autoscheduler: bool = False,
+    rpc_key: Optional[str] = None,
+    rpc_tracker: Optional[str] = None,
+    trials: Optional[int] = None,
+    target_host: str = "llvm",
+    tuner: str = "xgb",
+    min_repeat_ms: Optional[int] = None,
+    early_stopping: Optional[int] = None,
+    desired_layout: Optional[str] = None,
+    timeout: int = 10,
+    number: int = 10,
+    repeat: int = 1,
+    parallel: int = 4,
+    hardware_params: Optional[HardwareParams] = None,
+    include_simple_tasks: bool = False,
+    log_estimated_latency: bool = False,
+):
+    """Use tuning to automatically optimize the functions in a model.
+
+    Parameters
+    ----------
+    tvmc_model : TVMCModel
+        The model to be optimized.
+    target : str
+        Compilation target as plain string, inline JSON or path to a JSON file.
+    tuning_records: str, optional
+        The path to a file that tuning results will be saved to. If not specified,
+        a temporary file will be used.

Review comment:
       I think the correct way to solve this is to not require tuning records to be written to a file in the deeper APIs, but to keep that requirement at this level. We could do this by passing a stream, instead of a file path, to the tune callback.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to