junrushao commented on code in PR #12895:
URL: https://github.com/apache/tvm/pull/12895#discussion_r988422820


##########
python/tvm/meta_schedule/relay_integration.py:
##########
@@ -69,47 +124,229 @@ def extract_task_from_relay(
     """
     # pylint: disable=import-outside-toplevel
     from tvm import autotvm
-    from tvm.relay import Function as RelayFunc
 
     # pylint: enable=import-outside-toplevel
+    mod, target, params = _normalize_params(mod, target, params)
+    pass_config = dict(pass_config)
+    if target.kind.name != "cuda" and isinstance(
+        autotvm.DispatchContext.current, autotvm.FallbackContext
+    ):
+        tophub_context = autotvm.tophub.context(target)
+    else:
+        tophub_context = autotvm.utils.EmptyContext()
+    with Profiler.timeit("TaskExtraction"):
+        with target, _autotvm_silencer(), tophub_context:
+            with transform.PassContext(
+                opt_level=opt_level,
+                config=pass_config,
+            ):
+                return list(_extract_task(mod, target, params))
+
+
+def extracted_tasks_to_tune_contexts(
+    extracted_tasks: List[ExtractedTask],
+    work_dir: str,
+    space: SpaceGenerator.SpaceGeneratorType = "post-order-apply",
+    strategy: SearchStrategy.SearchStrategyType = "evolutionary",
+    num_threads: Union[Literal["physical", "logical"], int] = "physical",
+    seed: Optional[int] = None,
+) -> Tuple[List[TuneContext], List[float]]:
+    """Convert ExtractedTask to TuneContext.
+
+    Parameters
+    ----------
+    extracted_tasks : List[ExtractedTask]
+        The tasks to be converted
+    work_dir : str
+        The working directory to store logs and databases
+    space : SpaceGenerator.SpaceGeneratorType
+        The space generator to use.
+    strategy : SearchStrategy.SearchStrategyType
+        The search strategy to use.
+    num_threads : Union[Literal["physical", "logical"], int]
+        The number of threads to use.
+    seed : Optional[int]
+        The random seed to use.
+
+    Returns
+    -------
+    tasks : List[TuneContext]
+        The converted tasks
+    task_weights : List[float]
+        The weights of the tasks
+    """
+    tasks: List[TuneContext] = []
+    task_weights: List[float] = []
+    for task, logger, rand_state in zip(
+        extracted_tasks,
+        get_loggers_from_work_dir(work_dir, [t.task_name for t in extracted_tasks]),
+        fork_seed(seed, n=len(extracted_tasks)),
+    ):
+        tasks.append(
+            TuneContext(
+                mod=task.dispatched[0],
+                target=task.target,
+                space_generator=space,
+                search_strategy=strategy,
+                task_name=task.task_name,
+                logger=logger,
+                rand_state=rand_state,
+                num_threads=num_threads,
+            ).clone()
+        )
+        task_weights.append(task.weight)
+    return tasks, task_weights
+
 
-    extract_task_func = get_global_func(
-        "relay.backend.MetaScheduleExtractTask",
-        allow_missing=False,
+def tune_relay(
+    mod: IRModule,
+    params: Dict[str, NDArray],
+    target: Union[str, Target],
+    work_dir: str,
+    max_trials_global: int,
+    *,
+    max_trials_per_task: Optional[int] = None,
+    num_trials_per_iter: int = 64,
+    builder: Builder.BuilderType = "local",
+    runner: Runner.RunnerType = "local",
+    database: Database.DatabaseType = "json",
+    cost_model: CostModel.CostModelType = "xgb",
+    measure_callbacks: MeasureCallback.CallbackListType = "default",
+    task_scheduler: TaskScheduler.TaskSchedulerType = "gradient",
+    space: SpaceGenerator.SpaceGeneratorType = "post-order-apply",
+    strategy: SearchStrategy.SearchStrategyType = "evolutionary",
+    seed: Optional[int] = None,
+) -> Database:
+    """Tune a Relay program.
+
+    Parameters
+    ----------
+    mod : IRModule
+        The module or function to tune
+    params : Dict[str, tvm.runtime.NDArray]
+        The associated parameters of the program
+    target : Union[Target, str]
+        The compilation target
+    work_dir : str
+        The working directory to store the tuning records
+    max_trials_global : int
+        The maximum number of trials to run
+    max_trials_per_task : Optional[int]
+        The maximum number of trials to run for each task
+    num_trials_per_iter : int
+        The number of trials to run per iteration
+    builder : BuilderType
+        The builder to use
+    runner : RunnerType
+        The runner to use
+    database : DatabaseType
+        The database to use
+    cost_model : CostModelType
+        The cost model to use
+    measure_callbacks : CallbackListType
+        The measure callbacks to use
+    task_scheduler : TaskSchedulerType
+        The task scheduler to use
+    space : SpaceGeneratorType
+        The space generator to use
+    strategy : SearchStrategyType
+        The search strategy to use
+    seed : Optional[int]
+        The random seed
+
+    Returns
+    -------
+    database : Database
+        The database that contains the tuning records
+    """
+    tasks, task_weights = extracted_tasks_to_tune_contexts(
+        extracted_tasks=extract_tasks(mod, target, params),

Review Comment:
   let me try to add an `executor` parameter to this interface



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to