This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0fdc0eab51 [MetaSchedule] Distributed Measurement (#11683)
0fdc0eab51 is described below

commit 0fdc0eab5199d1b6549d2b2f94c83d86d5545e81
Author: Kathryn (Jinqi) Chen <[email protected]>
AuthorDate: Fri Jun 17 11:55:39 2022 -0700

    [MetaSchedule] Distributed Measurement (#11683)
    
    This PR adds distributed measurement of tuning candidates using a builder
    and an asynchronous runner, together with some auxiliary functions. It
    allows multiple builders and multiple runners to be used, with a tracker
    connecting them. The hierarchy of files in the database can still be
    compacted further to make the database more concise.
---
 include/tvm/meta_schedule/database.h               |  27 +++
 python/tvm/meta_schedule/database/database.py      |  34 ++++
 .../tvm/meta_schedule/database/memory_database.py  |   3 +
 .../testing/dataset_sample_candidates.py           |  23 +--
 .../testing/distributed_measure_candidates.py      | 198 +++++++++++++++++++++
 python/tvm/meta_schedule/tune_context.py           |  44 +++++
 src/meta_schedule/database/database.cc             |  22 ++-
 src/meta_schedule/database/json_database.cc        |   9 +
 src/meta_schedule/tune_context.cc                  |  14 +-
 9 files changed, 361 insertions(+), 13 deletions(-)

diff --git a/include/tvm/meta_schedule/database.h 
b/include/tvm/meta_schedule/database.h
index 37a315bf74..b22d8beddb 100644
--- a/include/tvm/meta_schedule/database.h
+++ b/include/tvm/meta_schedule/database.h
@@ -98,6 +98,9 @@ struct WorkloadEqual {
   }
 };
 
+/*! \brief The class of measure candidates. */
+class MeasureCandidate;
+
 /*! \brief The class of tuning records. */
 class TuningRecordNode : public runtime::Object {
  public:
@@ -123,6 +126,9 @@ class TuningRecordNode : public runtime::Object {
   static constexpr const char* _type_key = "meta_schedule.TuningRecord";
   TVM_DECLARE_FINAL_OBJECT_INFO(TuningRecordNode, runtime::Object);
 
+  /*! \brief Construct the measure candidate given the initial IR module and 
trace
+   * stored in the tuning record. */
+  MeasureCandidate AsMeasureCandidate() const;
   /*!
    * \brief Export the tuning record to a JSON string.
    * \return An array containing the trace, running secs, serialized target, 
and
@@ -187,6 +193,11 @@ class DatabaseNode : public runtime::Object {
    * \return An array of top K tuning records for the given workload.
    */
   virtual Array<TuningRecord> GetTopK(const Workload& workload, int top_k) = 0;
+  /*!
+   * \brief Get all tuning records from the database.
+   * \return An Array of all the tuning records in the database.
+   */
+  virtual Array<TuningRecord> GetAllTuningRecords() = 0;
   /*!
    * \brief Get the size of the database.
    * \return The size of the database.
@@ -224,6 +235,11 @@ class PyDatabaseNode : public DatabaseNode {
    * \return An array of top K tuning records for the given workload.
    */
   using FGetTopK = runtime::TypedPackedFunc<Array<TuningRecord>(const 
Workload&, int)>;
+  /*!
+   * \brief The function type of `GetAllTuningRecords` method.
+   * \return An Array of all the tuning records in the database.
+   */
+  using FGetAllTuningRecords = runtime::TypedPackedFunc<Array<TuningRecord>()>;
   /*!
    * \brief The function type of `Size` method.
    * \return The size of the database.
@@ -238,6 +254,8 @@ class PyDatabaseNode : public DatabaseNode {
   FCommitTuningRecord f_commit_tuning_record;
   /*! \brief The packed function to the `GetTopK` function. */
   FGetTopK f_get_top_k;
+  /*! \brief The packed function to the `GetAllTuningRecords` function. */
+  FGetAllTuningRecords f_get_all_tuning_records;
   /*! \brief The packed function to the `Size` function. */
   FSize f_size;
 
@@ -249,6 +267,7 @@ class PyDatabaseNode : public DatabaseNode {
     // `f_commit_workload` is not visited
     // `f_commit_tuning_record` is not visited
     // `f_get_top_k` is not visited
+    // `f_get_all_tuning_records` is not visited
     // `f_size` is not visited
   }
 
@@ -273,6 +292,12 @@ class PyDatabaseNode : public DatabaseNode {
     return f_get_top_k(workload, top_k);
   }
 
+  Array<TuningRecord> GetAllTuningRecords() final {
+    ICHECK(f_get_all_tuning_records != nullptr)
+        << "PyDatabase's GetAllTuningRecords method not implemented!";
+    return f_get_all_tuning_records();
+  }
+
   int64_t Size() final {
     ICHECK(f_size != nullptr) << "PyDatabase's Size method not implemented!";
     return f_size();
@@ -302,6 +327,7 @@ class Database : public runtime::ObjectRef {
    * \param f_commit_workload The packed function of `CommitWorkload`.
    * \param f_commit_tuning_record The packed function of `CommitTuningRecord`.
    * \param f_get_top_k The packed function of `GetTopK`.
+   * \param f_get_all_tuning_records The packed function of 
`GetAllTuningRecords`.
    * \param f_size The packed function of `Size`.
    * \return The created database.
    */
@@ -309,6 +335,7 @@ class Database : public runtime::ObjectRef {
                                      PyDatabaseNode::FCommitWorkload 
f_commit_workload,
                                      PyDatabaseNode::FCommitTuningRecord 
f_commit_tuning_record,
                                      PyDatabaseNode::FGetTopK f_get_top_k,
+                                     PyDatabaseNode::FGetAllTuningRecords 
f_get_all_tuning_records,
                                      PyDatabaseNode::FSize f_size);
   TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Database, 
runtime::ObjectRef, DatabaseNode);
 };
diff --git a/python/tvm/meta_schedule/database/database.py 
b/python/tvm/meta_schedule/database/database.py
index 802a739e69..0c11f77591 100644
--- a/python/tvm/meta_schedule/database/database.py
+++ b/python/tvm/meta_schedule/database/database.py
@@ -115,6 +115,17 @@ class TuningRecord(Object):
             args_info,
         )
 
+    def as_measure_candidate(self) -> Any:
+        """Generate a measure candidate given an initial IR module and a trace
+        stored in the tuning record.
+
+        Returns
+        -------
+        candidate : MeasureCandidate
+            A generated candidate.
+        """
+        return _ffi_api.TuningRecordAsMeasureCandidate(self)  # type: ignore # 
pylint: disable=no-member
+
     def as_json(self) -> Any:
         """Export the tuning record to a JSON string.
 
@@ -203,6 +214,16 @@ class Database(Object):
         """
         return _ffi_api.DatabaseGetTopK(self, workload, top_k)  # type: ignore 
# pylint: disable=no-member
 
+    def get_all_tuning_records(self) -> List[TuningRecord]:
+        """Get all the tuning records from the database.
+
+        Returns
+        -------
+        tuning_records : List[TuningRecord]
+            All tuning records from the database.
+        """
+        return _ffi_api.DatabaseGetAllTuningRecords(self)  # type: ignore # 
pylint: disable=no-member
+
     def __len__(self) -> int:
         """Get the number of records in the database.
 
@@ -229,6 +250,7 @@ class _PyDatabase(Database):
         f_commit_workload: Callable = None,
         f_commit_tuning_record: Callable = None,
         f_get_top_k: Callable = None,
+        f_get_all_tuning_records: Callable = None,
         f_size: Callable = None,
     ):
         """Constructor."""
@@ -239,6 +261,7 @@ class _PyDatabase(Database):
             f_commit_workload,
             f_commit_tuning_record,
             f_get_top_k,
+            f_get_all_tuning_records,
             f_size,
         )
 
@@ -258,6 +281,7 @@ class PyDatabase:
             "commit_workload",
             "commit_tuning_record",
             "get_top_k",
+            "get_all_tuning_records",
             "__len__",
         ],
     }
@@ -317,6 +341,16 @@ class PyDatabase:
         """
         raise NotImplementedError
 
+    def get_all_tuning_records(self) -> List[TuningRecord]:
+        """Get all the tuning records from the database.
+
+        Returns
+        -------
+        tuning_records : List[TuningRecord]
+            All tuning records from the database.
+        """
+        raise NotImplementedError
+
     def __len__(self) -> int:
         """Get the number of records in the database.
 
diff --git a/python/tvm/meta_schedule/database/memory_database.py 
b/python/tvm/meta_schedule/database/memory_database.py
index 6d10e4b527..95d937cc77 100644
--- a/python/tvm/meta_schedule/database/memory_database.py
+++ b/python/tvm/meta_schedule/database/memory_database.py
@@ -56,6 +56,9 @@ class MemoryDatabase(PyDatabase):
             )
         )[: int(top_k)]
 
+    def get_all_tuning_records(self) -> List[TuningRecord]:
+        return self.records
+
     def __len__(self) -> int:
         return len(self.records)
 
diff --git a/python/tvm/meta_schedule/testing/dataset_sample_candidates.py 
b/python/tvm/meta_schedule/testing/dataset_sample_candidates.py
index c80d78173e..35b872e735 100644
--- a/python/tvm/meta_schedule/testing/dataset_sample_candidates.py
+++ b/python/tvm/meta_schedule/testing/dataset_sample_candidates.py
@@ -103,6 +103,14 @@ def sample_candidates(task, task_name, model_name):
     -------
     None
     """
+    candidate_path = os.path.join(
+        args.candidate_cache_dir, model_name, task_name + "_candidates.json"
+    )
+    workload_path = os.path.join(args.candidate_cache_dir, model_name, 
task_name + "_workload.json")
+    database = ms.database.JSONDatabase(
+        path_workload=workload_path,
+        path_tuning_record=candidate_path,
+    )
     sample_init_population = tvm.get_global_func(
         "meta_schedule.SearchStrategyEvolutionarySearchSampleInitPopulation"
     )
@@ -128,7 +136,7 @@ def sample_candidates(task, task_name, model_name):
     context.initialize()
     context.pre_tuning(
         context.generate_design_space(),
-        database=ms.database.MemoryDatabase(),  # type: ignore
+        database=database,
         cost_model=ms.cost_model.RandomModel(),  # type: ignore
     )
 
@@ -148,16 +156,9 @@ def sample_candidates(task, task_name, model_name):
     all_states = all_states[: args.num_samples_per_task]
 
     workload = ms.database.Workload(context.mod)
-    file_path = os.path.join(args.candidate_cache_dir, model_name, task_name + 
".json")
-    with open(file_path, "w", encoding="utf8") as file:
-        for i, state in enumerate(all_states):
-            tuning_record = ms.database.TuningRecord(state.trace, workload)
-            json_str = json.dumps(tuning_record.as_json())
-            assert "\n" not in json_str, "Failed to generate single line 
string."
-            if i == len(all_states) - 1:
-                file.write(json_str)
-            else:
-                file.write(json_str + "\n")
+    database.commit_workload(context.mod)
+    for state in all_states:
+        database.commit_tuning_record(ms.database.TuningRecord(state.trace, 
workload))
 
 
 args = _parse_args()  # pylint: disable=invalid-name
diff --git a/python/tvm/meta_schedule/testing/distributed_measure_candidates.py 
b/python/tvm/meta_schedule/testing/distributed_measure_candidates.py
new file mode 100644
index 0000000000..8e646c4846
--- /dev/null
+++ b/python/tvm/meta_schedule/testing/distributed_measure_candidates.py
@@ -0,0 +1,198 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+
+import argparse
+import glob
+import os
+
+from tqdm import tqdm  # type: ignore
+from tvm import meta_schedule as ms
+from tvm.target import Target
+
+
+def _parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--candidate_cache_dir", type=str, help="Please provide the full path 
to the candidates."
+    )
+    parser.add_argument(
+        "--result_cache_dir", type=str, help="Please provide the full path to 
the result database."
+    )
+    parser.add_argument(
+        "--target",
+        type=str,
+        default="nvidia/nvidia-v100",
+        help="Please specify the target hardware for tuning context.",
+    )
+    parser.add_argument(
+        "--rpc_host", type=str, help="Please provide the private IPv4 address 
for the tracker."
+    )
+    parser.add_argument(
+        "--rpc_port", type=int, default=4445, help="Please provide the port 
for the tracker."
+    )
+    parser.add_argument(
+        "--rpc_key",
+        type=str,
+        default="p3.2xlarge",
+        help="Please provide the key for the rpc servers.",
+    )
+    parser.add_argument(
+        "--builder_timeout_sec",
+        type=int,
+        default=10,
+        help="The time for the builder session to time out.",
+    )
+    parser.add_argument(
+        "--min_repeat_ms", type=int, default=100, help="The time for 
preheating the gpu."
+    )
+    parser.add_argument(
+        "--runner_timeout_sec",
+        type=int,
+        default=100,
+        help="The time for the runner session to time out.",
+    )
+    parser.add_argument(
+        "--cpu_flush", type=bool, default=False, help="Whether to enable cpu 
cache flush or not."
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=128,
+        help="The batch size of candidates sent to builder and runner each 
time.",
+    )
+    return parser.parse_args()
+
+
+# pylint: disable=too-many-locals
+def measure_candidates(database, builder, runner):
+    """Send the candidates to builder and runner for distributed measurement,
+    and save the results in a new json database.
+
+    Parameters
+    ----------
+    database : JSONDatabase
+        The database for candidates to be measured.
+    builder : Builder
+        The builder for building the candidates.
+    runner : Runner
+        The runner for measuring the candidates.
+
+    Returns
+    -------
+    None
+    """
+    candidates, runner_results, build_fail_indices, run_fail_indices = [], [], 
[], []
+    context = ms.TuneContext(target=Target(args.target))
+    tuning_records = database.get_all_tuning_records()
+    for record in tuning_records:
+        candidates.append(record.as_measure_candidate())
+    with ms.Profiler() as profiler:
+        for idx in range(0, len(candidates), args.batch_size):
+            batch_candidates = candidates[idx : idx + args.batch_size]
+            context._set_measure_candidates(batch_candidates)  # pylint: 
disable=protected-access
+            with ms.Profiler.timeit("build"):
+                context._send_to_builder(builder)  # pylint: 
disable=protected-access
+            with ms.Profiler.timeit("run"):
+                context._send_to_runner(runner)  # pylint: 
disable=protected-access
+                batch_runner_results = context._join()  # pylint: 
disable=protected-access
+            runner_results.extend(batch_runner_results)
+            for i, result in enumerate(context.builder_results):
+                if result.error_msg is None:
+                    ms.utils.remove_build_dir(result.artifact_path)
+                else:
+                    build_fail_indices.append(i + idx)
+            context._clear_measure_state()  # pylint: disable=protected-access
+
+    model_name, workload_name = database.path_workload.split("/")[-2:]
+    record_name = database.path_tuning_record.split("/")[-1]
+    new_database = ms.database.JSONDatabase(
+        path_workload=os.path.join(args.result_cache_dir, model_name, 
workload_name),
+        path_tuning_record=os.path.join(args.result_cache_dir, model_name, 
record_name),
+    )
+    workload = tuning_records[0].workload
+    new_database.commit_workload(workload.mod)
+    for i, (record, result) in enumerate(zip(tuning_records, runner_results)):
+        if result.error_msg is None:
+            new_database.commit_tuning_record(
+                ms.database.TuningRecord(
+                    trace=record.trace,
+                    workload=workload,
+                    run_secs=[v.value for v in result.run_secs],
+                    target=Target(args.target),
+                )
+            )
+        else:
+            run_fail_indices.append(i)
+    fail_indices_name = workload_name.replace("_workload.json", 
"_failed_indices.txt")
+    with open(
+        os.path.join(args.result_cache_dir, model_name, fail_indices_name), 
"w", encoding="utf8"
+    ) as file:
+        file.write(" ".join([str(n) for n in run_fail_indices]))
+    print(
+        f"Builder time: {profiler.get()['build']}, Runner time: 
{profiler.get()['run']}\n\
+            Failed number of builds: {len(build_fail_indices)},\
+            Failed number of runs: {len(run_fail_indices)}"
+    )
+
+
+args = _parse_args()  # pylint: disable=invalid-name
+
+
+def main():
+    builder = ms.builder.LocalBuilder(timeout_sec=args.builder_timeout_sec)
+    runner = ms.runner.RPCRunner(
+        rpc_config=ms.runner.RPCConfig(
+            tracker_host=args.rpc_host,
+            tracker_port=args.rpc_port,
+            tracker_key=args.rpc_key,
+            session_timeout_sec=args.runner_timeout_sec,
+        ),
+        evaluator_config=ms.runner.EvaluatorConfig(
+            number=3,
+            repeat=1,
+            min_repeat_ms=args.min_repeat_ms,
+            enable_cpu_cache_flush=args.cpu_flush,
+        ),
+        max_workers=os.cpu_count(),
+    )
+    if not os.path.isdir(args.candidate_cache_dir):
+        raise Exception("Please provide a correct candidate cache dir.")
+    try:
+        os.makedirs(args.result_cache_dir, exist_ok=True)
+    except OSError:
+        print(f"Directory {args.result_cache_dir} cannot be created 
successfully.")
+    model_dirs = glob.glob(os.path.join(args.candidate_cache_dir, "*"))
+    for model_dir in model_dirs:
+        model_name = model_dir.split("/")[-1]
+        os.makedirs(os.path.join(args.result_cache_dir, model_name), 
exist_ok=True)
+        all_tasks = glob.glob(os.path.join(model_dir, "*.json"))
+        workload_paths = []
+        for path in all_tasks:
+            if path.endswith("_workload.json"):
+                workload_paths.append(path)
+        for workload_path in tqdm(workload_paths):
+            candidate_path = workload_path.replace("_workload.json", 
"_candidates.json")
+            database = ms.database.JSONDatabase(
+                path_workload=workload_path,
+                path_tuning_record=candidate_path,
+            )
+            measure_candidates(database, builder, runner)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/python/tvm/meta_schedule/tune_context.py 
b/python/tvm/meta_schedule/tune_context.py
index b7975e7b2c..30c726ded2 100644
--- a/python/tvm/meta_schedule/tune_context.py
+++ b/python/tvm/meta_schedule/tune_context.py
@@ -171,6 +171,50 @@ class TuneContext(Object):
         )
         _ffi_api.TuneContextInitialize(self)  # type: ignore # pylint: 
disable=no-member
 
+    def _set_measure_candidates(self, candidates):
+        """Set candidates in a tuning context.
+
+        Parameters
+        ----------
+        candidates : List[MeasureCandidate]
+            A list of measure candidates for the tuning context.
+        """
+        _ffi_api.TuneContextSetMeasureCandidates(self, candidates)  # type: 
ignore # pylint: disable=no-member
+
+    def _send_to_builder(self, builder):
+        """Send candidates to builder.
+
+        Parameters
+        ----------
+        builder : Builder
+            The builder for building the candidates.
+        """
+        _ffi_api.TuneContextSendToBuilder(self, builder)  # type: ignore # 
pylint: disable=no-member
+
+    def _send_to_runner(self, runner):
+        """Send candidates to runner.
+
+        Parameters
+        ----------
+        runner : Runner
+            The runner for running the candidates.
+        """
+        _ffi_api.TuneContextSendToRunner(self, runner)  # type: ignore # 
pylint: disable=no-member
+
+    def _join(self):
+        """Join the runner processes.
+
+        Returns
+        -------
+        result : List[RunnerResult]
+            The runner results.
+        """
+        return _ffi_api.TuneContextJoin(self)  # type: ignore # pylint: 
disable=no-member
+
+    def _clear_measure_state(self):
+        """Clear the measure states."""
+        _ffi_api.TuneContextClearMeasureState(self)  # type: ignore # pylint: 
disable=no-member
+
     def generate_design_space(self) -> List[Schedule]:
         """Generate design spaces given a module.
 
diff --git a/src/meta_schedule/database/database.cc 
b/src/meta_schedule/database/database.cc
index 9905ff73c7..5adff49984 100644
--- a/src/meta_schedule/database/database.cc
+++ b/src/meta_schedule/database/database.cc
@@ -85,6 +85,19 @@ TuningRecord::TuningRecord(tir::Trace trace, Workload 
workload, Optional<Array<F
   this->data_ = n;
 }
 
+MeasureCandidate TuningRecordNode::AsMeasureCandidate() const {
+  tir::Schedule sch =
+      tir::Schedule::Traced(workload->mod, -1, 0, 
tir::ScheduleErrorRenderLevel::kDetail);
+  trace->ApplyToSchedule(sch, false, nullptr);
+  tir::PrimFunc func;
+  for (const auto& kv : sch->mod()->functions) {
+    func = Downcast<tir::PrimFunc>(kv.second);
+  }
+  Array<ArgInfo> args_info = ArgInfo::FromPrimFunc(func);
+  MeasureCandidate candidate = MeasureCandidate(sch, args_info);
+  return candidate;
+}
+
 ObjectRef TuningRecordNode::AsJSON() const {
   Optional<Array<ObjectRef>> json_args_info{nullptr};
   Optional<ObjectRef> json_target{nullptr};
@@ -152,12 +165,15 @@ TuningRecord TuningRecord::FromJSON(const ObjectRef& 
json_obj, const Workload& w
 Database Database::PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload,
                               PyDatabaseNode::FCommitWorkload 
f_commit_workload,
                               PyDatabaseNode::FCommitTuningRecord 
f_commit_tuning_record,
-                              PyDatabaseNode::FGetTopK f_get_top_k, 
PyDatabaseNode::FSize f_size) {
+                              PyDatabaseNode::FGetTopK f_get_top_k,
+                              PyDatabaseNode::FGetAllTuningRecords 
f_get_all_tuning_records,
+                              PyDatabaseNode::FSize f_size) {
   ObjectPtr<PyDatabaseNode> n = make_object<PyDatabaseNode>();
   n->f_has_workload = f_has_workload;
   n->f_commit_workload = f_commit_workload;
   n->f_commit_tuning_record = f_commit_tuning_record;
   n->f_get_top_k = f_get_top_k;
+  n->f_get_all_tuning_records = f_get_all_tuning_records;
   n->f_size = f_size;
   return Database(n);
 }
@@ -179,6 +195,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.TuningRecord")
                        Optional<Target> target, Optional<Array<ArgInfo>> 
args_info) {
       return TuningRecord(trace, workload, run_secs, target, args_info);
     });
+TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordAsMeasureCandidate")
+    .set_body_method<TuningRecord>(&TuningRecordNode::AsMeasureCandidate);
 TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordAsJSON")
     .set_body_method<TuningRecord>(&TuningRecordNode::AsJSON);
 
TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordFromJSON").set_body_typed(TuningRecord::FromJSON);
@@ -190,6 +208,8 @@ 
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseCommitTuningRecord")
     .set_body_method<Database>(&DatabaseNode::CommitTuningRecord);
 TVM_REGISTER_GLOBAL("meta_schedule.DatabaseGetTopK")
     .set_body_method<Database>(&DatabaseNode::GetTopK);
+TVM_REGISTER_GLOBAL("meta_schedule.DatabaseGetAllTuningRecords")
+    .set_body_method<Database>(&DatabaseNode::GetAllTuningRecords);
 
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseSize").set_body_method<Database>(&DatabaseNode::Size);
 
TVM_REGISTER_GLOBAL("meta_schedule.DatabasePyDatabase").set_body_typed(Database::PyDatabase);
 
diff --git a/src/meta_schedule/database/json_database.cc 
b/src/meta_schedule/database/json_database.cc
index 4f5bd9b136..9bb7ee1027 100644
--- a/src/meta_schedule/database/json_database.cc
+++ b/src/meta_schedule/database/json_database.cc
@@ -156,6 +156,15 @@ class JSONDatabaseNode : public DatabaseNode {
     return results;
   }
 
+  Array<TuningRecord> GetAllTuningRecords() {
+    Array<TuningRecord> results;
+    results.reserve(Size());
+    for (const TuningRecord& record : this->tuning_records_) {
+      results.push_back(record);
+    }
+    return results;
+  }
+
   int64_t Size() { return tuning_records_.size(); }
 };
 
diff --git a/src/meta_schedule/tune_context.cc 
b/src/meta_schedule/tune_context.cc
index 0c70dcf5c4..57b2344c6f 100644
--- a/src/meta_schedule/tune_context.cc
+++ b/src/meta_schedule/tune_context.cc
@@ -142,7 +142,9 @@ Array<RunnerResult> TuneContextNode::_Join() {
       results.push_back(future->Result());
     }
   }
-  
this->search_strategy.value()->NotifyRunnerResults(this->measure_candidates.value(),
 results);
+  if (this->search_strategy.defined()) {
+    
this->search_strategy.value()->NotifyRunnerResults(this->measure_candidates.value(),
 results);
+  }
   ICHECK(this->measure_candidates.defined());
   ICHECK(this->builder_results.defined());
   ICHECK_EQ(results.size(), this->measure_candidates.value().size());
@@ -177,6 +179,16 @@ TVM_REGISTER_GLOBAL("meta_schedule.TuneContext")
 TVM_REGISTER_GLOBAL("meta_schedule._SHash2Hex").set_body_typed(SHash2Hex);
 TVM_REGISTER_GLOBAL("meta_schedule.TuneContextInitialize")
     .set_body_method<TuneContext>(&TuneContextNode::Initialize);
+TVM_REGISTER_GLOBAL("meta_schedule.TuneContextSetMeasureCandidates")
+    .set_body_method<TuneContext>(&TuneContextNode::_SetMeasureCandidates);
+TVM_REGISTER_GLOBAL("meta_schedule.TuneContextSendToBuilder")
+    .set_body_method<TuneContext>(&TuneContextNode::_SendToBuilder);
+TVM_REGISTER_GLOBAL("meta_schedule.TuneContextSendToRunner")
+    .set_body_method<TuneContext>(&TuneContextNode::_SendToRunner);
+TVM_REGISTER_GLOBAL("meta_schedule.TuneContextJoin")
+    .set_body_method<TuneContext>(&TuneContextNode::_Join);
+TVM_REGISTER_GLOBAL("meta_schedule.TuneContextClearMeasureState")
+    .set_body_method<TuneContext>(&TuneContextNode::_ClearMeasureState);
 
 }  // namespace meta_schedule
 }  // namespace tvm

Reply via email to