This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0fdc0eab51 [MetaSchedule] Distributed Measurement (#11683)
0fdc0eab51 is described below
commit 0fdc0eab5199d1b6549d2b2f94c83d86d5545e81
Author: Kathryn (Jinqi) Chen <[email protected]>
AuthorDate: Fri Jun 17 11:55:39 2022 -0700
[MetaSchedule] Distributed Measurement (#11683)
This PR includes the distributed measurement of tuning candidates using
builder and async runner, as well as some auxiliary functions. It enables
multiple builders and multiple runners, with a tracker connecting them.
The file hierarchy of the database can be further compacted to make the
database more concise.
---
include/tvm/meta_schedule/database.h | 27 +++
python/tvm/meta_schedule/database/database.py | 34 ++++
.../tvm/meta_schedule/database/memory_database.py | 3 +
.../testing/dataset_sample_candidates.py | 23 +--
.../testing/distributed_measure_candidates.py | 198 +++++++++++++++++++++
python/tvm/meta_schedule/tune_context.py | 44 +++++
src/meta_schedule/database/database.cc | 22 ++-
src/meta_schedule/database/json_database.cc | 9 +
src/meta_schedule/tune_context.cc | 14 +-
9 files changed, 361 insertions(+), 13 deletions(-)
diff --git a/include/tvm/meta_schedule/database.h
b/include/tvm/meta_schedule/database.h
index 37a315bf74..b22d8beddb 100644
--- a/include/tvm/meta_schedule/database.h
+++ b/include/tvm/meta_schedule/database.h
@@ -98,6 +98,9 @@ struct WorkloadEqual {
}
};
+/*! \brief The class of measure candidates. */
+class MeasureCandidate;
+
/*! \brief The class of tuning records. */
class TuningRecordNode : public runtime::Object {
public:
@@ -123,6 +126,9 @@ class TuningRecordNode : public runtime::Object {
static constexpr const char* _type_key = "meta_schedule.TuningRecord";
TVM_DECLARE_FINAL_OBJECT_INFO(TuningRecordNode, runtime::Object);
+ /*! \brief Construct the measure candidate given the initial IR module and
trace
+ * stored in the tuning record. */
+ MeasureCandidate AsMeasureCandidate() const;
/*!
* \brief Export the tuning record to a JSON string.
* \return An array containing the trace, running secs, serialized target,
and
@@ -187,6 +193,11 @@ class DatabaseNode : public runtime::Object {
* \return An array of top K tuning records for the given workload.
*/
virtual Array<TuningRecord> GetTopK(const Workload& workload, int top_k) = 0;
+ /*!
+ * \brief Get all tuning records from the database.
+ * \return An Array of all the tuning records in the database.
+ */
+ virtual Array<TuningRecord> GetAllTuningRecords() = 0;
/*!
* \brief Get the size of the database.
* \return The size of the database.
@@ -224,6 +235,11 @@ class PyDatabaseNode : public DatabaseNode {
* \return An array of top K tuning records for the given workload.
*/
using FGetTopK = runtime::TypedPackedFunc<Array<TuningRecord>(const
Workload&, int)>;
+ /*!
+ * \brief The function type of `GetAllTuningRecords` method.
+ * \return An Array of all the tuning records in the database.
+ */
+ using FGetAllTuningRecords = runtime::TypedPackedFunc<Array<TuningRecord>()>;
/*!
* \brief The function type of `Size` method.
* \return The size of the database.
@@ -238,6 +254,8 @@ class PyDatabaseNode : public DatabaseNode {
FCommitTuningRecord f_commit_tuning_record;
/*! \brief The packed function to the `GetTopK` function. */
FGetTopK f_get_top_k;
+ /*! \brief The packed function to the `GetAllTuningRecords` function. */
+ FGetAllTuningRecords f_get_all_tuning_records;
/*! \brief The packed function to the `Size` function. */
FSize f_size;
@@ -249,6 +267,7 @@ class PyDatabaseNode : public DatabaseNode {
// `f_commit_workload` is not visited
// `f_commit_tuning_record` is not visited
// `f_get_top_k` is not visited
+ // `f_get_all_tuning_records` is not visited
// `f_size` is not visited
}
@@ -273,6 +292,12 @@ class PyDatabaseNode : public DatabaseNode {
return f_get_top_k(workload, top_k);
}
+ Array<TuningRecord> GetAllTuningRecords() final {
+ ICHECK(f_get_all_tuning_records != nullptr)
+ << "PyDatabase's GetAllTuningRecords method not implemented!";
+ return f_get_all_tuning_records();
+ }
+
int64_t Size() final {
ICHECK(f_size != nullptr) << "PyDatabase's Size method not implemented!";
return f_size();
@@ -302,6 +327,7 @@ class Database : public runtime::ObjectRef {
* \param f_commit_workload The packed function of `CommitWorkload`.
* \param f_commit_tuning_record The packed function of `CommitTuningRecord`.
* \param f_get_top_k The packed function of `GetTopK`.
+ * \param f_get_all_tuning_records The packed function of
`GetAllTuningRecords`.
* \param f_size The packed function of `Size`.
* \return The created database.
*/
@@ -309,6 +335,7 @@ class Database : public runtime::ObjectRef {
PyDatabaseNode::FCommitWorkload
f_commit_workload,
PyDatabaseNode::FCommitTuningRecord
f_commit_tuning_record,
PyDatabaseNode::FGetTopK f_get_top_k,
+ PyDatabaseNode::FGetAllTuningRecords
f_get_all_tuning_records,
PyDatabaseNode::FSize f_size);
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Database,
runtime::ObjectRef, DatabaseNode);
};
diff --git a/python/tvm/meta_schedule/database/database.py
b/python/tvm/meta_schedule/database/database.py
index 802a739e69..0c11f77591 100644
--- a/python/tvm/meta_schedule/database/database.py
+++ b/python/tvm/meta_schedule/database/database.py
@@ -115,6 +115,17 @@ class TuningRecord(Object):
args_info,
)
+ def as_measure_candidate(self) -> Any:
+ """Generate a measure candidate given an initial IR module and a trace
+ stored in the tuning record.
+
+ Returns
+ -------
+ candidate : MeasureCandidate
+ A generated candidate.
+ """
+ return _ffi_api.TuningRecordAsMeasureCandidate(self) # type: ignore #
pylint: disable=no-member
+
def as_json(self) -> Any:
"""Export the tuning record to a JSON string.
@@ -203,6 +214,16 @@ class Database(Object):
"""
return _ffi_api.DatabaseGetTopK(self, workload, top_k) # type: ignore
# pylint: disable=no-member
+ def get_all_tuning_records(self) -> List[TuningRecord]:
+ """Get all the tuning records from the database.
+
+ Returns
+ -------
+ tuning_records : List[TuningRecord]
+ All tuning records from the database.
+ """
+ return _ffi_api.DatabaseGetAllTuningRecords(self) # type: ignore #
pylint: disable=no-member
+
def __len__(self) -> int:
"""Get the number of records in the database.
@@ -229,6 +250,7 @@ class _PyDatabase(Database):
f_commit_workload: Callable = None,
f_commit_tuning_record: Callable = None,
f_get_top_k: Callable = None,
+ f_get_all_tuning_records: Callable = None,
f_size: Callable = None,
):
"""Constructor."""
@@ -239,6 +261,7 @@ class _PyDatabase(Database):
f_commit_workload,
f_commit_tuning_record,
f_get_top_k,
+ f_get_all_tuning_records,
f_size,
)
@@ -258,6 +281,7 @@ class PyDatabase:
"commit_workload",
"commit_tuning_record",
"get_top_k",
+ "get_all_tuning_records",
"__len__",
],
}
@@ -317,6 +341,16 @@ class PyDatabase:
"""
raise NotImplementedError
+ def get_all_tuning_records(self) -> List[TuningRecord]:
+ """Get all the tuning records from the database.
+
+ Returns
+ -------
+ tuning_records : List[TuningRecord]
+ All tuning records from the database.
+ """
+ raise NotImplementedError
+
def __len__(self) -> int:
"""Get the number of records in the database.
diff --git a/python/tvm/meta_schedule/database/memory_database.py
b/python/tvm/meta_schedule/database/memory_database.py
index 6d10e4b527..95d937cc77 100644
--- a/python/tvm/meta_schedule/database/memory_database.py
+++ b/python/tvm/meta_schedule/database/memory_database.py
@@ -56,6 +56,9 @@ class MemoryDatabase(PyDatabase):
)
)[: int(top_k)]
+ def get_all_tuning_records(self) -> List[TuningRecord]:
+ return self.records
+
def __len__(self) -> int:
return len(self.records)
diff --git a/python/tvm/meta_schedule/testing/dataset_sample_candidates.py
b/python/tvm/meta_schedule/testing/dataset_sample_candidates.py
index c80d78173e..35b872e735 100644
--- a/python/tvm/meta_schedule/testing/dataset_sample_candidates.py
+++ b/python/tvm/meta_schedule/testing/dataset_sample_candidates.py
@@ -103,6 +103,14 @@ def sample_candidates(task, task_name, model_name):
-------
None
"""
+ candidate_path = os.path.join(
+ args.candidate_cache_dir, model_name, task_name + "_candidates.json"
+ )
+ workload_path = os.path.join(args.candidate_cache_dir, model_name,
task_name + "_workload.json")
+ database = ms.database.JSONDatabase(
+ path_workload=workload_path,
+ path_tuning_record=candidate_path,
+ )
sample_init_population = tvm.get_global_func(
"meta_schedule.SearchStrategyEvolutionarySearchSampleInitPopulation"
)
@@ -128,7 +136,7 @@ def sample_candidates(task, task_name, model_name):
context.initialize()
context.pre_tuning(
context.generate_design_space(),
- database=ms.database.MemoryDatabase(), # type: ignore
+ database=database,
cost_model=ms.cost_model.RandomModel(), # type: ignore
)
@@ -148,16 +156,9 @@ def sample_candidates(task, task_name, model_name):
all_states = all_states[: args.num_samples_per_task]
workload = ms.database.Workload(context.mod)
- file_path = os.path.join(args.candidate_cache_dir, model_name, task_name +
".json")
- with open(file_path, "w", encoding="utf8") as file:
- for i, state in enumerate(all_states):
- tuning_record = ms.database.TuningRecord(state.trace, workload)
- json_str = json.dumps(tuning_record.as_json())
- assert "\n" not in json_str, "Failed to generate single line
string."
- if i == len(all_states) - 1:
- file.write(json_str)
- else:
- file.write(json_str + "\n")
+ database.commit_workload(context.mod)
+ for state in all_states:
+ database.commit_tuning_record(ms.database.TuningRecord(state.trace,
workload))
args = _parse_args() # pylint: disable=invalid-name
diff --git a/python/tvm/meta_schedule/testing/distributed_measure_candidates.py
b/python/tvm/meta_schedule/testing/distributed_measure_candidates.py
new file mode 100644
index 0000000000..8e646c4846
--- /dev/null
+++ b/python/tvm/meta_schedule/testing/distributed_measure_candidates.py
@@ -0,0 +1,198 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+
+import argparse
+import glob
+import os
+
+from tqdm import tqdm # type: ignore
+from tvm import meta_schedule as ms
+from tvm.target import Target
+
+
+def _parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--candidate_cache_dir", type=str, help="Please provide the full path
to the candidates."
+ )
+ parser.add_argument(
+ "--result_cache_dir", type=str, help="Please provide the full path to
the result database."
+ )
+ parser.add_argument(
+ "--target",
+ type=str,
+ default="nvidia/nvidia-v100",
+ help="Please specify the target hardware for tuning context.",
+ )
+ parser.add_argument(
+ "--rpc_host", type=str, help="Please provide the private IPv4 address
for the tracker."
+ )
+ parser.add_argument(
+ "--rpc_port", type=int, default=4445, help="Please provide the port
for the tracker."
+ )
+ parser.add_argument(
+ "--rpc_key",
+ type=str,
+ default="p3.2xlarge",
+ help="Please provide the key for the rpc servers.",
+ )
+ parser.add_argument(
+ "--builder_timeout_sec",
+ type=int,
+ default=10,
+ help="The time for the builder session to time out.",
+ )
+ parser.add_argument(
+ "--min_repeat_ms", type=int, default=100, help="The time for
preheating the gpu."
+ )
+ parser.add_argument(
+ "--runner_timeout_sec",
+ type=int,
+ default=100,
+ help="The time for the runner session to time out.",
+ )
+ parser.add_argument(
+ "--cpu_flush", type=bool, default=False, help="Whether to enable cpu
cache flush or not."
+ )
+ parser.add_argument(
+ "--batch_size",
+ type=int,
+ default=128,
+ help="The batch size of candidates sent to builder and runner each
time.",
+ )
+ return parser.parse_args()
+
+
+# pylint: disable=too-many-locals
+def measure_candidates(database, builder, runner):
+ """Send the candidates to builder and runner for distributed measurement,
+ and save the results in a new json database.
+
+ Parameters
+ ----------
+ database : JSONDatabase
+ The database for candidates to be measured.
+ builder : Builder
+ The builder for building the candidates.
+ runner : Runner
+ The runner for measuring the candidates.
+
+ Returns
+ -------
+ None
+ """
+ candidates, runner_results, build_fail_indices, run_fail_indices = [], [],
[], []
+ context = ms.TuneContext(target=Target(args.target))
+ tuning_records = database.get_all_tuning_records()
+ for record in tuning_records:
+ candidates.append(record.as_measure_candidate())
+ with ms.Profiler() as profiler:
+ for idx in range(0, len(candidates), args.batch_size):
+ batch_candidates = candidates[idx : idx + args.batch_size]
+ context._set_measure_candidates(batch_candidates) # pylint:
disable=protected-access
+ with ms.Profiler.timeit("build"):
+ context._send_to_builder(builder) # pylint:
disable=protected-access
+ with ms.Profiler.timeit("run"):
+ context._send_to_runner(runner) # pylint:
disable=protected-access
+ batch_runner_results = context._join() # pylint:
disable=protected-access
+ runner_results.extend(batch_runner_results)
+ for i, result in enumerate(context.builder_results):
+ if result.error_msg is None:
+ ms.utils.remove_build_dir(result.artifact_path)
+ else:
+ build_fail_indices.append(i + idx)
+ context._clear_measure_state() # pylint: disable=protected-access
+
+ model_name, workload_name = database.path_workload.split("/")[-2:]
+ record_name = database.path_tuning_record.split("/")[-1]
+ new_database = ms.database.JSONDatabase(
+ path_workload=os.path.join(args.result_cache_dir, model_name,
workload_name),
+ path_tuning_record=os.path.join(args.result_cache_dir, model_name,
record_name),
+ )
+ workload = tuning_records[0].workload
+ new_database.commit_workload(workload.mod)
+ for i, (record, result) in enumerate(zip(tuning_records, runner_results)):
+ if result.error_msg is None:
+ new_database.commit_tuning_record(
+ ms.database.TuningRecord(
+ trace=record.trace,
+ workload=workload,
+ run_secs=[v.value for v in result.run_secs],
+ target=Target(args.target),
+ )
+ )
+ else:
+ run_fail_indices.append(i)
+ fail_indices_name = workload_name.replace("_workload.json",
"_failed_indices.txt")
+ with open(
+ os.path.join(args.result_cache_dir, model_name, fail_indices_name),
"w", encoding="utf8"
+ ) as file:
+ file.write(" ".join([str(n) for n in run_fail_indices]))
+ print(
+ f"Builder time: {profiler.get()['build']}, Runner time:
{profiler.get()['run']}\n\
+ Failed number of builds: {len(build_fail_indices)},\
+ Failed number of runs: {len(run_fail_indices)}"
+ )
+
+
+args = _parse_args() # pylint: disable=invalid-name
+
+
+def main():
+ builder = ms.builder.LocalBuilder(timeout_sec=args.builder_timeout_sec)
+ runner = ms.runner.RPCRunner(
+ rpc_config=ms.runner.RPCConfig(
+ tracker_host=args.rpc_host,
+ tracker_port=args.rpc_port,
+ tracker_key=args.rpc_key,
+ session_timeout_sec=args.runner_timeout_sec,
+ ),
+ evaluator_config=ms.runner.EvaluatorConfig(
+ number=3,
+ repeat=1,
+ min_repeat_ms=args.min_repeat_ms,
+ enable_cpu_cache_flush=args.cpu_flush,
+ ),
+ max_workers=os.cpu_count(),
+ )
+ if not os.path.isdir(args.candidate_cache_dir):
+ raise Exception("Please provide a correct candidate cache dir.")
+ try:
+ os.makedirs(args.result_cache_dir, exist_ok=True)
+ except OSError:
+ print(f"Directory {args.result_cache_dir} cannot be created
successfully.")
+ model_dirs = glob.glob(os.path.join(args.candidate_cache_dir, "*"))
+ for model_dir in model_dirs:
+ model_name = model_dir.split("/")[-1]
+ os.makedirs(os.path.join(args.result_cache_dir, model_name),
exist_ok=True)
+ all_tasks = glob.glob(os.path.join(model_dir, "*.json"))
+ workload_paths = []
+ for path in all_tasks:
+ if path.endswith("_workload.json"):
+ workload_paths.append(path)
+ for workload_path in tqdm(workload_paths):
+ candidate_path = workload_path.replace("_workload.json",
"_candidates.json")
+ database = ms.database.JSONDatabase(
+ path_workload=workload_path,
+ path_tuning_record=candidate_path,
+ )
+ measure_candidates(database, builder, runner)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/python/tvm/meta_schedule/tune_context.py
b/python/tvm/meta_schedule/tune_context.py
index b7975e7b2c..30c726ded2 100644
--- a/python/tvm/meta_schedule/tune_context.py
+++ b/python/tvm/meta_schedule/tune_context.py
@@ -171,6 +171,50 @@ class TuneContext(Object):
)
_ffi_api.TuneContextInitialize(self) # type: ignore # pylint:
disable=no-member
+ def _set_measure_candidates(self, candidates):
+ """Set candidates in a tuning context.
+
+ Parameters
+ ----------
+ candidates : List[MeasureCandidate]
+ A list of measure candidates for the tuning context.
+ """
+ _ffi_api.TuneContextSetMeasureCandidates(self, candidates) # type:
ignore # pylint: disable=no-member
+
+ def _send_to_builder(self, builder):
+ """Send candidates to builder.
+
+ Parameters
+ ----------
+ builder : Builder
+ The builder for building the candidates.
+ """
+ _ffi_api.TuneContextSendToBuilder(self, builder) # type: ignore #
pylint: disable=no-member
+
+ def _send_to_runner(self, runner):
+ """Send candidates to runner.
+
+ Parameters
+ ----------
+ runner : Runner
+ The runner for running the candidates.
+ """
+ _ffi_api.TuneContextSendToRunner(self, runner) # type: ignore #
pylint: disable=no-member
+
+ def _join(self):
+ """Join the runner processes.
+
+ Returns
+ -------
+ result : List[RunnerResult]
+ The runner results.
+ """
+ return _ffi_api.TuneContextJoin(self) # type: ignore # pylint:
disable=no-member
+
+ def _clear_measure_state(self):
+ """Clear the measure states."""
+ _ffi_api.TuneContextClearMeasureState(self) # type: ignore # pylint:
disable=no-member
+
def generate_design_space(self) -> List[Schedule]:
"""Generate design spaces given a module.
diff --git a/src/meta_schedule/database/database.cc
b/src/meta_schedule/database/database.cc
index 9905ff73c7..5adff49984 100644
--- a/src/meta_schedule/database/database.cc
+++ b/src/meta_schedule/database/database.cc
@@ -85,6 +85,19 @@ TuningRecord::TuningRecord(tir::Trace trace, Workload
workload, Optional<Array<F
this->data_ = n;
}
+MeasureCandidate TuningRecordNode::AsMeasureCandidate() const {
+ tir::Schedule sch =
+ tir::Schedule::Traced(workload->mod, -1, 0,
tir::ScheduleErrorRenderLevel::kDetail);
+ trace->ApplyToSchedule(sch, false, nullptr);
+ tir::PrimFunc func;
+ for (const auto& kv : sch->mod()->functions) {
+ func = Downcast<tir::PrimFunc>(kv.second);
+ }
+ Array<ArgInfo> args_info = ArgInfo::FromPrimFunc(func);
+ MeasureCandidate candidate = MeasureCandidate(sch, args_info);
+ return candidate;
+}
+
ObjectRef TuningRecordNode::AsJSON() const {
Optional<Array<ObjectRef>> json_args_info{nullptr};
Optional<ObjectRef> json_target{nullptr};
@@ -152,12 +165,15 @@ TuningRecord TuningRecord::FromJSON(const ObjectRef&
json_obj, const Workload& w
Database Database::PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload,
PyDatabaseNode::FCommitWorkload
f_commit_workload,
PyDatabaseNode::FCommitTuningRecord
f_commit_tuning_record,
- PyDatabaseNode::FGetTopK f_get_top_k,
PyDatabaseNode::FSize f_size) {
+ PyDatabaseNode::FGetTopK f_get_top_k,
+ PyDatabaseNode::FGetAllTuningRecords
f_get_all_tuning_records,
+ PyDatabaseNode::FSize f_size) {
ObjectPtr<PyDatabaseNode> n = make_object<PyDatabaseNode>();
n->f_has_workload = f_has_workload;
n->f_commit_workload = f_commit_workload;
n->f_commit_tuning_record = f_commit_tuning_record;
n->f_get_top_k = f_get_top_k;
+ n->f_get_all_tuning_records = f_get_all_tuning_records;
n->f_size = f_size;
return Database(n);
}
@@ -179,6 +195,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.TuningRecord")
Optional<Target> target, Optional<Array<ArgInfo>>
args_info) {
return TuningRecord(trace, workload, run_secs, target, args_info);
});
+TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordAsMeasureCandidate")
+ .set_body_method<TuningRecord>(&TuningRecordNode::AsMeasureCandidate);
TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordAsJSON")
.set_body_method<TuningRecord>(&TuningRecordNode::AsJSON);
TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordFromJSON").set_body_typed(TuningRecord::FromJSON);
@@ -190,6 +208,8 @@
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseCommitTuningRecord")
.set_body_method<Database>(&DatabaseNode::CommitTuningRecord);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseGetTopK")
.set_body_method<Database>(&DatabaseNode::GetTopK);
+TVM_REGISTER_GLOBAL("meta_schedule.DatabaseGetAllTuningRecords")
+ .set_body_method<Database>(&DatabaseNode::GetAllTuningRecords);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseSize").set_body_method<Database>(&DatabaseNode::Size);
TVM_REGISTER_GLOBAL("meta_schedule.DatabasePyDatabase").set_body_typed(Database::PyDatabase);
diff --git a/src/meta_schedule/database/json_database.cc
b/src/meta_schedule/database/json_database.cc
index 4f5bd9b136..9bb7ee1027 100644
--- a/src/meta_schedule/database/json_database.cc
+++ b/src/meta_schedule/database/json_database.cc
@@ -156,6 +156,15 @@ class JSONDatabaseNode : public DatabaseNode {
return results;
}
+ Array<TuningRecord> GetAllTuningRecords() {
+ Array<TuningRecord> results;
+ results.reserve(Size());
+ for (const TuningRecord& record : this->tuning_records_) {
+ results.push_back(record);
+ }
+ return results;
+ }
+
int64_t Size() { return tuning_records_.size(); }
};
diff --git a/src/meta_schedule/tune_context.cc
b/src/meta_schedule/tune_context.cc
index 0c70dcf5c4..57b2344c6f 100644
--- a/src/meta_schedule/tune_context.cc
+++ b/src/meta_schedule/tune_context.cc
@@ -142,7 +142,9 @@ Array<RunnerResult> TuneContextNode::_Join() {
results.push_back(future->Result());
}
}
-
this->search_strategy.value()->NotifyRunnerResults(this->measure_candidates.value(),
results);
+ if (this->search_strategy.defined()) {
+
this->search_strategy.value()->NotifyRunnerResults(this->measure_candidates.value(),
results);
+ }
ICHECK(this->measure_candidates.defined());
ICHECK(this->builder_results.defined());
ICHECK_EQ(results.size(), this->measure_candidates.value().size());
@@ -177,6 +179,16 @@ TVM_REGISTER_GLOBAL("meta_schedule.TuneContext")
TVM_REGISTER_GLOBAL("meta_schedule._SHash2Hex").set_body_typed(SHash2Hex);
TVM_REGISTER_GLOBAL("meta_schedule.TuneContextInitialize")
.set_body_method<TuneContext>(&TuneContextNode::Initialize);
+TVM_REGISTER_GLOBAL("meta_schedule.TuneContextSetMeasureCandidates")
+ .set_body_method<TuneContext>(&TuneContextNode::_SetMeasureCandidates);
+TVM_REGISTER_GLOBAL("meta_schedule.TuneContextSendToBuilder")
+ .set_body_method<TuneContext>(&TuneContextNode::_SendToBuilder);
+TVM_REGISTER_GLOBAL("meta_schedule.TuneContextSendToRunner")
+ .set_body_method<TuneContext>(&TuneContextNode::_SendToRunner);
+TVM_REGISTER_GLOBAL("meta_schedule.TuneContextJoin")
+ .set_body_method<TuneContext>(&TuneContextNode::_Join);
+TVM_REGISTER_GLOBAL("meta_schedule.TuneContextClearMeasureState")
+ .set_body_method<TuneContext>(&TuneContextNode::_ClearMeasureState);
} // namespace meta_schedule
} // namespace tvm