comaniac commented on a change in pull request #7376:
URL: https://github.com/apache/tvm/pull/7376#discussion_r571141012
##########
File path: python/tvm/auto_scheduler/dispatcher.py
##########
@@ -301,6 +306,85 @@ def update(self, target, workload_key, state):
entry[workload_args] = (state, 1)
+class ApplyHistoryBestOrSample(ApplyHistoryBest):
+    """
+    Apply the history best config, or sample a valid schedule if no config is found.
+
+    Parameters
+    ----------
+    records : str or iterator of (auto_scheduler.measure.MeasureInput,\
+        auto_scheduler.measure.MeasureResult)
+        Collection of tuning records.
+        If it is a str, it should be the filename of a records log file;
+        each row of the file is an encoded record pair. Otherwise, it is an iterator.
+    sample_simple_workloads: bool
+        When False, sampling will not apply to simple workloads (w/o reduction).
+    cost_model_file: str
+        The filename of the pre-trained XGBoost cost model. If not present, a random
+        model will be used.
+    """
+
+    def __init__(self, records, sample_simple_workloads=False, cost_model_file=None):
+        self.sample_simple_workloads = sample_simple_workloads
+        self.log_dir = tempdir()
+        if cost_model_file is None:
+            self.cost_model = RandomModel()
+        else:
+            self.cost_model = XGBModel(num_warmup_sample=1, model_file=cost_model_file)
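The dispatch logic this class adds can be sketched in plain Python: prefer the best tuned record for a workload, and fall back to sampling (guided by `XGBModel` or `RandomModel` in the real class) only when no record exists. The names `best_records`, `sample_schedule`, and `query` below are illustrative stand-ins, not TVM APIs.

```python
import random

# Hypothetical stand-in for tuning records loaded from a log file.
best_records = {"matmul_512": "tuned_schedule_A"}

def sample_schedule(workload_key, rng):
    # Stand-in for cost-model-guided sampling of a valid schedule.
    return f"sampled_schedule_{rng.randint(0, 9)}_for_{workload_key}"

def query(workload_key, rng=random.Random(0)):
    # Prefer the history best; sample only when no record is found.
    if workload_key in best_records:
        return best_records[workload_key]
    return sample_schedule(workload_key, rng)

print(query("matmul_512"))   # hits tuning history
print(query("conv2d_new"))   # falls back to sampling
```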
Review comment:
Ah, I see. It is confusing, and we should improve it later.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]