comaniac commented on a change in pull request #7028:
URL: https://github.com/apache/tvm/pull/7028#discussion_r535668480
##########
File path: python/tvm/auto_scheduler/search_task.py
##########
@@ -42,18 +153,124 @@ class SearchTask(Object):
The target host device of this search task.
hardware_params : Optional[HardwareParams]
Hardware parameters used in this search task.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ # We support two ways to create a search task
+
+ # Way 1: create a task by a workload generation function.
+ # The `workload_func` is a function decorated by
@auto_scheduler.register_workload
+ task = SearchTask(func=workload_func, args=args, target=target)
+
+ # Way 2: create a task by a workload_key.
+ # The `workload_key` is a string, which can be either a hash key or a
json-serialized
+ # tuple(func, args).
+ task = SearchTask(workload_key=workload_key, target=target)
"""
- def __init__(self, dag, workload_key, target, target_host=None,
hardware_params=None):
- self.dag = dag
+ def __init__(
+ self,
+ func=None,
+ args=None,
+ compute_dag=None,
+ workload_key=None,
+ target=None,
+ target_host=None,
+ hardware_params=None,
+ ):
+ assert (
+ func is not None or workload_key is not None
+ ), "Either a workload generation function or a workload key should be
provided"
+
+ if func is not None:
+ workload_key = make_workload_key(func, args)
+ if compute_dag is None:
+ compute_dag = ComputeDAG(workload_key)
+
+ assert target is not None, "Must specify a target."
+ if isinstance(target, str):
+ target = Target(target)
+ if isinstance(target_host, str):
+ target_host = Target(target_host)
+
+ self.dag = compute_dag
self.workload_key = workload_key
self.target = target
self.target_host = target_host
self.hardware_params = hardware_params
self.__init_handle_by_constructor__(
- _ffi_api.SearchTask, dag, workload_key, target, target_host,
hardware_params
+ _ffi_api.SearchTask, compute_dag, workload_key, target,
target_host, hardware_params
)
+ def tune(self, tuning_options, search_policy=None):
+ """Run auto scheduling search for a task
+
+ Parameters
+ ----------
+ tuning_options : Optional[TuningOptions]
+ Tuning and measurement options.
+ search_policy : Optional[SearchPolicy]
+ The search policy to be used for schedule search.
+ """
+ if search_policy is None:
+ cost_model = XGBModel()
+ search_policy = SketchPolicy(self, cost_model)
+
+ _ffi_api.AutoSchedule(search_policy, tuning_options)
+
+ def apply_best(self, log_file, layout_rewrite_option=None):
+ """Apply the history best from a log file and return the schedule.
+
+ Parameters
+ ----------
+ log_file : str
+ The name of the log file
+ layout_rewrite_option : Optional[LayoutRewriteOption]
+ The layout rewrite option
+
+ Returns
+ -------
+ A `te.Schedule` and the a list of `te.Tensor` to be used in
`tvm.lower` or `tvm.build`.
+ """
+ inp, res = load_best_record(log_file, self.workload_key)
+
+ if layout_rewrite_option is None:
+ layout_rewrite_option = LayoutRewriteOption.NO_REWRITE
+ if self.target.kind.name == "llvm":
+ layout_rewrite_option =
LayoutRewriteOption.INSERT_TRANSFORM_STAGE
+ sch, args = self.compute_dag.apply_steps_from_state(inp.state,
layout_rewrite_option)
+ return sch, args
+
+ def print_best(self, log_file, print_mode="schedule"):
+ """Print the best schedule as python schedule API code or CUDA source
code.
+
+ Parameters
+ ----------
+ log_file : str
+ The name of the log file
+ print_mode: str
+ if "schedule", print the best schedule as python schedule API code.
+ if "cude", print the best schedule as CUDA source code.
Review comment:
- s/cude/cuda
- This looks inconsistent with "schedule". Maybe just name it "code" and raise a
RuntimeError if the target is not CUDA?
##########
File path: tutorials/auto_scheduler/tune_matmul_x86.py
##########
@@ -147,22 +156,10 @@ def matmul_add(N, L, M, dtype):
# file "matmul.json". The measurement records can be used to re-apply search
results,
# resume the search, and perform other analyses.
-######################################################################
-# Here is an example where we load the best schedule from a file,
-# print the equivalent python schedule API, and build the binary again.
-
-# Load the measuremnt record for the best schedule
-inp, res = auto_scheduler.load_best(log_file, task.workload_key)
Review comment:
It would be better to remind readers that we use `task.apply_best` to load and
apply the best schedule; otherwise it looks a bit odd to show only how to print
the best schedule here.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]