This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 57a13a2324 [Metaschedule] Aligning get_top_k logic in MemoryDatabase and JSONDatabase (#13611)
57a13a2324 is described below

commit 57a13a23243a4e70ea67a2c54f402ff7e9017d8a
Author: Alexey Voronov <[email protected]>
AuthorDate: Thu Dec 15 04:51:54 2022 +0300

    [Metaschedule] Aligning get_top_k logic in MemoryDatabase and JSONDatabase (#13611)
    
    [Metaschedule] Align get_top_k logic in MemoryDatabase and JSONDatabase
---
 src/meta_schedule/database/json_database.cc        | 10 ++++--
 src/meta_schedule/database/memory_database.cc      | 12 +++++--
 .../python/unittest/test_meta_schedule_database.py | 39 ++++++++++++++++++++++
 3 files changed, 57 insertions(+), 4 deletions(-)

diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc
index bd5183f0cf..22d6ec849c 100644
--- a/src/meta_schedule/database/json_database.cc
+++ b/src/meta_schedule/database/json_database.cc
@@ -126,16 +126,22 @@ class JSONDatabaseNode : public DatabaseNode {
     }
     Array<TuningRecord> results;
     results.reserve(top_k);
-    int counter = 0;
     for (const TuningRecord& record : this->tuning_records_) {
+      if (!record->run_secs.defined() || record->run_secs.value().empty()) {
+        continue;
+      }
       if (record->workload.same_as(workload) ||
           WorkloadEqual(GetModuleEquality())(record->workload, workload)) {
         results.push_back(record);
-        if (++counter == top_k) {
+        if (results.size() == static_cast<size_t>(top_k)) {
           break;
         }
       }
     }
+    if (results.size() < static_cast<size_t>(top_k)) {
+      LOG(WARNING) << "The size of the GetTopK result is smaller than requested. There are not "
+                      "enough valid records in the database for this workload.";
+    }
     return results;
   }
 
diff --git a/src/meta_schedule/database/memory_database.cc b/src/meta_schedule/database/memory_database.cc
index 24fba6dfa1..19178a35f4 100644
--- a/src/meta_schedule/database/memory_database.cc
+++ b/src/meta_schedule/database/memory_database.cc
@@ -61,8 +61,12 @@ class MemoryDatabaseNode : public DatabaseNode {
   void CommitTuningRecord(const TuningRecord& record) final { records.push_back(record); }
 
   Array<TuningRecord> GetTopK(const Workload& workload, int top_k) final {
+    CHECK_GE(top_k, 0) << "ValueError: top_k must be non-negative";
+    if (top_k == 0) {
+      return {};
+    }
     std::vector<std::pair<double, TuningRecord>> results;
-    results.reserve(this->records.size());
+    results.reserve(records.size());
     for (const TuningRecord& record : records) {
       if (!record->run_secs.defined()) {
         continue;
@@ -83,7 +87,7 @@ class MemoryDatabaseNode : public DatabaseNode {
     std::sort(results.begin(), results.end());
     auto begin = results.begin();
     auto end = results.end();
-    if (static_cast<int>(results.size()) > top_k) {
+    if (results.size() > static_cast<size_t>(top_k)) {
       end = begin + top_k;
     }
     Array<TuningRecord> ret;
@@ -92,6 +96,10 @@ class MemoryDatabaseNode : public DatabaseNode {
       ret.push_back(begin->second);
       ++begin;
     }
+    if (ret.size() < static_cast<size_t>(top_k)) {
+      LOG(WARNING) << "The size of the GetTopK result is smaller than requested. There are not "
+                      "enough valid records in the database for this workload.";
+    }
     return ret;
   }
 
diff --git a/tests/python/unittest/test_meta_schedule_database.py b/tests/python/unittest/test_meta_schedule_database.py
index 777c5589a1..4ec10b556c 100644
--- a/tests/python/unittest/test_meta_schedule_database.py
+++ b/tests/python/unittest/test_meta_schedule_database.py
@@ -18,6 +18,7 @@
 """Test Meta Schedule Database"""
 import os.path as osp
 import tempfile
+import pytest
 from typing import Callable, Optional, List
 
 import tvm
@@ -536,5 +537,43 @@ def test_meta_schedule_pydatabase_current():
         assert ms.database.Database.current() == db
 
 
+def call_get_top_k(run_secs_list, database, k):
+    mod: IRModule = Matmul
+    workload = database.commit_workload(mod)
+    for run_secs in run_secs_list:
+        record = ms.database.TuningRecord(
+            _create_schedule(mod, _schedule_matmul).trace,
+            workload,
+            run_secs,
+            tvm.target.Target("llvm"),
+            ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]),
+        )
+        database.commit_tuning_record(record)
+    return [[v.value for v in record.run_secs] for record in database.get_top_k(workload, k)]
+
+
+@pytest.mark.parametrize(
+    "k,expected",
+    [(0, []), (3, [[0.0, 2.0], [2.0], [1.5, 4.5]]), (5, [[0.0, 2.0], [2.0], [1.5, 4.5]])],
+)
+def test_memory_database_get_top_k(k, expected):
+    run_secs_list = [[1.5, 4.5], [], [0.0, 2.0], None, [2.0]]
+    database = ms.database.MemoryDatabase()
+    result = call_get_top_k(run_secs_list, database, k)
+    assert result == expected
+
+
+@pytest.mark.parametrize(
+    "k,expected",
+    [(0, []), (3, [[0.0, 2.0], [2.0], [1.5, 4.5]]), (5, [[0.0, 2.0], [2.0], [1.5, 4.5]])],
+)
+def test_json_database_get_top_k(k, expected):
+    run_secs_list = [[1.5, 4.5], [], [0.0, 2.0], None, [2.0]]
+    with tempfile.TemporaryDirectory() as tmpdir:
+        database = _create_tmp_database(tmpdir)
+        result = call_get_top_k(run_secs_list, database, k)
+    assert result == expected
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to