comaniac commented on a change in pull request #7028:
URL: https://github.com/apache/tvm/pull/7028#discussion_r535668480



##########
File path: python/tvm/auto_scheduler/search_task.py
##########
@@ -42,18 +153,124 @@ class SearchTask(Object):
         The target host device of this search task.
     hardware_params : Optional[HardwareParams]
         Hardware parameters used in this search task.
+
+    Examples
+    --------
+    .. code-block:: python
+
+      # We support two ways to create a search task
+
+      # Way 1: create a task by a workload generation function.
+      # The `workload_func` is a function decorated by 
@auto_scheduler.register_workload
+      task = SearchTask(func=workload_func, args=args, target=target)
+
+      # Way 2: create a task by a workload_key.
+      # The `workload_key` is a string, which can be either a hash key or a 
json-serialized
+      # tuple(func, args).
+      task = SearchTask(workload_key=workload_key, target=target)
     """
 
-    def __init__(self, dag, workload_key, target, target_host=None, 
hardware_params=None):
-        self.dag = dag
+    def __init__(
+        self,
+        func=None,
+        args=None,
+        compute_dag=None,
+        workload_key=None,
+        target=None,
+        target_host=None,
+        hardware_params=None,
+    ):
+        assert (
+            func is not None or workload_key is not None
+        ), "Either a workload generation function or a workload key should be 
provided"
+
+        if func is not None:
+            workload_key = make_workload_key(func, args)
+        if compute_dag is None:
+            compute_dag = ComputeDAG(workload_key)
+
+        assert target is not None, "Must specify a target."
+        if isinstance(target, str):
+            target = Target(target)
+        if isinstance(target_host, str):
+            target_host = Target(target_host)
+
+        self.dag = compute_dag
         self.workload_key = workload_key
         self.target = target
         self.target_host = target_host
         self.hardware_params = hardware_params
         self.__init_handle_by_constructor__(
-            _ffi_api.SearchTask, dag, workload_key, target, target_host, 
hardware_params
+            _ffi_api.SearchTask, compute_dag, workload_key, target, 
target_host, hardware_params
         )
 
+    def tune(self, tuning_options, search_policy=None):
+        """Run auto scheduling search for a task
+
+        Parameters
+        ----------
+        tuning_options : Optional[TuningOptions]
+            Tuning and measurement options.
+        search_policy : Optional[SearchPolicy]
+            The search policy to be used for schedule search.
+        """
+        if search_policy is None:
+            cost_model = XGBModel()
+            search_policy = SketchPolicy(self, cost_model)
+
+        _ffi_api.AutoSchedule(search_policy, tuning_options)
+
+    def apply_best(self, log_file, layout_rewrite_option=None):
+        """Apply the history best from a log file and return the schedule.
+
+        Parameters
+        ----------
+        log_file : str
+           The name of the log file
+        layout_rewrite_option : Optional[LayoutRewriteOption]
+           The layout rewrite option
+
+        Returns
+        -------
+            A `te.Schedule` and the a list of `te.Tensor` to be used in 
`tvm.lower` or `tvm.build`.
+        """
+        inp, res = load_best_record(log_file, self.workload_key)
+
+        if layout_rewrite_option is None:
+            layout_rewrite_option = LayoutRewriteOption.NO_REWRITE
+            if self.target.kind.name == "llvm":
+                layout_rewrite_option = 
LayoutRewriteOption.INSERT_TRANSFORM_STAGE
+        sch, args = self.compute_dag.apply_steps_from_state(inp.state, 
layout_rewrite_option)
+        return sch, args
+
+    def print_best(self, log_file, print_mode="schedule"):
+        """Print the best schedule as python schedule API code or CUDA source 
code.
+
+        Parameters
+        ----------
+        log_file : str
+           The name of the log file
+        print_mode: str
+           if "schedule", print the best schedule as python schedule API code.
+           if "cude", print the best schedule as CUDA source code.

Review comment:
   - s/cude/cuda
   - This looks inconsistent with "schedule". Maybe just name it "code" and throw 
a RuntimeError if the target is not CUDA?
   

##########
File path: tutorials/auto_scheduler/tune_matmul_x86.py
##########
@@ -147,22 +156,10 @@ def matmul_add(N, L, M, dtype):
 # file "matmul.json". The measurement records can be used to re-apply search 
results,
 # resume the search, and perform other analyses.
 
-######################################################################
-# Here is an example where we load the best schedule from a file,
-# print the equivalent python schedule API, and build the binary again.
-
-# Load the measuremnt record for the best schedule
-inp, res = auto_scheduler.load_best(log_file, task.workload_key)

Review comment:
       It would be better to remind readers that we use `task.apply_best` to load 
and apply the best schedule; otherwise it looks a bit odd to only show how to print 
the best schedule here.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to