This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 57a13a2324 [Metaschedule] Aligning get_top_k logic in MemoryDatabase
and JSONDatabase (#13611)
57a13a2324 is described below
commit 57a13a23243a4e70ea67a2c54f402ff7e9017d8a
Author: Alexey Voronov <[email protected]>
AuthorDate: Thu Dec 15 04:51:54 2022 +0300
[Metaschedule] Aligning get_top_k logic in MemoryDatabase and JSONDatabase
(#13611)
[Metaschedule] Align get_top_k logic in MemoryDatabase and JSONDatabase
---
src/meta_schedule/database/json_database.cc | 10 ++++--
src/meta_schedule/database/memory_database.cc | 12 +++++--
.../python/unittest/test_meta_schedule_database.py | 39 ++++++++++++++++++++++
3 files changed, 57 insertions(+), 4 deletions(-)
diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc
index bd5183f0cf..22d6ec849c 100644
--- a/src/meta_schedule/database/json_database.cc
+++ b/src/meta_schedule/database/json_database.cc
@@ -126,16 +126,22 @@ class JSONDatabaseNode : public DatabaseNode {
}
Array<TuningRecord> results;
results.reserve(top_k);
- int counter = 0;
for (const TuningRecord& record : this->tuning_records_) {
+ if (!record->run_secs.defined() || record->run_secs.value().empty()) {
+ continue;
+ }
if (record->workload.same_as(workload) ||
WorkloadEqual(GetModuleEquality())(record->workload, workload)) {
results.push_back(record);
- if (++counter == top_k) {
+ if (results.size() == static_cast<size_t>(top_k)) {
break;
}
}
}
+ if (results.size() < static_cast<size_t>(top_k)) {
+ LOG(WARNING) << "The size of the GetTopK result is smaller than requested. There are not "
+ "enough valid records in the database for this workload.";
+ }
return results;
}
diff --git a/src/meta_schedule/database/memory_database.cc b/src/meta_schedule/database/memory_database.cc
index 24fba6dfa1..19178a35f4 100644
--- a/src/meta_schedule/database/memory_database.cc
+++ b/src/meta_schedule/database/memory_database.cc
@@ -61,8 +61,12 @@ class MemoryDatabaseNode : public DatabaseNode {
void CommitTuningRecord(const TuningRecord& record) final { records.push_back(record); }
Array<TuningRecord> GetTopK(const Workload& workload, int top_k) final {
+ CHECK_GE(top_k, 0) << "ValueError: top_k must be non-negative";
+ if (top_k == 0) {
+ return {};
+ }
std::vector<std::pair<double, TuningRecord>> results;
- results.reserve(this->records.size());
+ results.reserve(records.size());
for (const TuningRecord& record : records) {
if (!record->run_secs.defined()) {
continue;
@@ -83,7 +87,7 @@ class MemoryDatabaseNode : public DatabaseNode {
std::sort(results.begin(), results.end());
auto begin = results.begin();
auto end = results.end();
- if (static_cast<int>(results.size()) > top_k) {
+ if (results.size() > static_cast<size_t>(top_k)) {
end = begin + top_k;
}
Array<TuningRecord> ret;
@@ -92,6 +96,10 @@ class MemoryDatabaseNode : public DatabaseNode {
ret.push_back(begin->second);
++begin;
}
+ if (ret.size() < static_cast<size_t>(top_k)) {
+ LOG(WARNING) << "The size of the GetTopK result is smaller than requested. There are not "
+ "enough valid records in the database for this workload.";
+ }
return ret;
}
diff --git a/tests/python/unittest/test_meta_schedule_database.py b/tests/python/unittest/test_meta_schedule_database.py
index 777c5589a1..4ec10b556c 100644
--- a/tests/python/unittest/test_meta_schedule_database.py
+++ b/tests/python/unittest/test_meta_schedule_database.py
@@ -18,6 +18,7 @@
"""Test Meta Schedule Database"""
import os.path as osp
import tempfile
+import pytest
from typing import Callable, Optional, List
import tvm
@@ -536,5 +537,43 @@ def test_meta_schedule_pydatabase_current():
assert ms.database.Database.current() == db
+def call_get_top_k(run_secs_list, database, k):
+ mod: IRModule = Matmul
+ workload = database.commit_workload(mod)
+ for run_secs in run_secs_list:
+ record = ms.database.TuningRecord(
+ _create_schedule(mod, _schedule_matmul).trace,
+ workload,
+ run_secs,
+ tvm.target.Target("llvm"),
+ ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]),
+ )
+ database.commit_tuning_record(record)
+ return [[v.value for v in record.run_secs] for record in database.get_top_k(workload, k)]
+
+
[email protected](
+ "k,expected",
+ [(0, []), (3, [[0.0, 2.0], [2.0], [1.5, 4.5]]), (5, [[0.0, 2.0], [2.0], [1.5, 4.5]])],
+)
+def test_memory_database_get_top_k(k, expected):
+ run_secs_list = [[1.5, 4.5], [], [0.0, 2.0], None, [2.0]]
+ database = ms.database.MemoryDatabase()
+ result = call_get_top_k(run_secs_list, database, k)
+ assert result == expected
+
+
[email protected](
+ "k,expected",
+ [(0, []), (3, [[0.0, 2.0], [2.0], [1.5, 4.5]]), (5, [[0.0, 2.0], [2.0], [1.5, 4.5]])],
+)
+def test_json_database_get_top_k(k, expected):
+ run_secs_list = [[1.5, 4.5], [], [0.0, 2.0], None, [2.0]]
+ with tempfile.TemporaryDirectory() as tmpdir:
+ database = _create_tmp_database(tmpdir)
+ result = call_get_top_k(run_secs_list, database, k)
+ assert result == expected
+
+
if __name__ == "__main__":
tvm.testing.main()