zxybazh commented on code in PR #10986:
URL: https://github.com/apache/tvm/pull/10986#discussion_r851891047
##########
python/tvm/meta_schedule/tune.py:
##########
@@ -449,95 +336,189 @@ def _mutator_probs(
# pylint: enable=protected-access
raise ValueError(f"Unsupported target: {target}")
- @staticmethod
- def _tune_context(
- tune_context: Optional[TuneContext],
- mod: IRModule,
- target: Target,
- config: SearchStrategyConfig,
- task_name: str,
- space_generator: Optional[FnSpaceGenerator],
- sch_rules: Optional[FnScheduleRule],
- postprocs: Optional[FnPostproc],
- mutator_probs: Optional[FnMutatorProb],
- num_threads: Optional[int],
- ) -> TuneContext:
- if tune_context is None:
- return TuneContext(
- mod=mod,
- target=target,
- # pylint: disable=protected-access
- space_generator=Parse._space_generator(space_generator),
- search_strategy=config.create_strategy(),
- sch_rules=Parse._sch_rules(sch_rules, target),
- postprocs=Parse._postproc(postprocs, target),
- mutator_probs=Parse._mutator_probs(mutator_probs, target),
- # pylint: enable=protected-access
- task_name=task_name,
- rand_state=-1,
- num_threads=num_threads,
- )
- if not isinstance(tune_context, TuneContext):
- raise TypeError(f"Expected `tune_context` to be TuneContext, but gets: {tune_context}")
- return tune_context
- @staticmethod
- def _task_scheduler(
- task_scheduler: Union[None, TaskScheduler, FnTaskScheduler],
- tasks: List[TuneContext],
- task_weights: List[float],
- builder: Builder,
- runner: Runner,
- database: Database,
- max_trials: int,
- cost_model: CostModel,
- measure_callbacks: List[MeasureCallback],
- ):
- if task_scheduler is None:
- return GradientBased(
- tasks=tasks,
- task_weights=task_weights,
- builder=builder,
- runner=runner,
- database=database,
- max_trials=max_trials,
- cost_model=cost_model,
- measure_callbacks=measure_callbacks,
+class TuneConfig(NamedTuple):
+ """Configuration for tuning
+
+ Parameters
+ ----------
+ max_trials_global: int
+ Maximum number of trials to run.
+ num_trials_per_iter: int
+ Number of trials to run per iteration.
+ max_trials_per_task: int
+ Maximum number of trials to run per task.
+ task_scheduler: str
+ Task scheduler to use.
+ Valid options are: round_robin, gradient.
+ strategy: str
+ Search strategy to use.
+ Valid options are: evolutionary, replay_func, replay_trace.
+ task_scheduler_config: Dict[str, Any]
+ Configuration for task scheduler.
+ search_strategy_config: Dict[str, Any]
+ Configuration for search strategy.
+ """
+
+ max_trials_global: int
+ num_trials_per_iter: int
+ max_trials_per_task: Optional[int] = None
+ task_scheduler: str = "gradient"
+ strategy: str = "evolutionary"
+ task_scheduler_config: Dict[str, Any] = {}
+ search_strategy_config: Dict[str, Any] = {}
+
Review Comment:
Sure, I'll add that to my backlog.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]