This is an automated email from the ASF dual-hosted git repository.
zhaowu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b08e8e4 [MetaSchedule] Add the missing HasWorkload interface to the Database (#9756)
b08e8e4 is described below
commit b08e8e49b5ea0e406c3ca41610a131dd271f0d41
Author: Junru Shao <[email protected]>
AuthorDate: Thu Dec 16 22:16:31 2021 -0800
[MetaSchedule] Add the missing HasWorkload interface to the Database (#9756)
---
include/tvm/meta_schedule/database.h | 25 ++++++++++++++++++++--
python/tvm/meta_schedule/database/database.py | 18 ++++++++++++++++
src/meta_schedule/database/database.cc | 6 +++++-
src/meta_schedule/database/json_database.cc | 4 ++++
.../python/unittest/test_meta_schedule_database.py | 20 ++++++++++++++++-
.../unittest/test_meta_schedule_task_scheduler.py | 3 +++
6 files changed, 72 insertions(+), 4 deletions(-)
diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h
index 60c6898..f07d8e1 100644
--- a/include/tvm/meta_schedule/database.h
+++ b/include/tvm/meta_schedule/database.h
@@ -156,6 +156,12 @@ class DatabaseNode : public runtime::Object {
/*! \brief Default destructor */
virtual ~DatabaseNode() = default;
/*!
+ * \brief Check if the database has the given workload.
+ * \param mod The IRModule to be searched for.
+ * \return Whether the database has the given workload.
+ */
+ virtual bool HasWorkload(const IRModule& mod) = 0;
+ /*!
* \brief Look up or add workload to the database if missing.
* \param mod The IRModule to be searched for or added.
* \return The workload corresponding to the given IRModule.
@@ -187,6 +193,12 @@ class DatabaseNode : public runtime::Object {
class PyDatabaseNode : public DatabaseNode {
public:
/*!
+ * \brief The function type of `HasWorkload` method.
+ * \param mod The IRModule to be searched for.
+ * \return Whether the database has the given workload.
+ */
+ using FHasWorkload = runtime::TypedPackedFunc<bool(const IRModule&)>;
+ /*!
* \brief The function type of `CommitWorkload` method.
* \param mod The IRModule to be searched for or added.
* \return The workload corresponding to the given IRModule.
@@ -210,6 +222,8 @@ class PyDatabaseNode : public DatabaseNode {
*/
using FSize = runtime::TypedPackedFunc<int64_t()>;
+ /*! \brief The packed function to the `HasWorkload` function. */
+ FHasWorkload f_has_workload;
/*! \brief The packed function to the `CommitWorkload` function. */
FCommitWorkload f_commit_workload;
/*! \brief The packed function to the `CommitTuningRecord` function. */
@@ -223,13 +237,18 @@ class PyDatabaseNode : public DatabaseNode {
// PackedFuncs are all not visited, because the reflection system doesn't take care of them,
// so it cannot be accessible on the python side. If there is such need from the future,
// we can then add corresponding accessor methods to help access on python.
- //
+ // `f_has_workload` is not visited
// `f_commit_workload` is not visited
// `f_commit_tuning_record` is not visited
// `f_get_top_k` is not visited
// `f_size` is not visited
}
+ bool HasWorkload(const IRModule& mod) final {
+ ICHECK(f_has_workload != nullptr) << "PyDatabase's HasWorkload method not implemented!";
+ return f_has_workload(mod);
+ }
+
Workload CommitWorkload(const IRModule& mod) final {
ICHECK(f_commit_workload != nullptr) << "PyDatabase's CommitWorkload method not implemented!";
return f_commit_workload(mod);
@@ -271,13 +290,15 @@ class Database : public runtime::ObjectRef {
bool allow_missing);
/*!
* \brief Create a database with customized methods on the python-side.
+ * \param f_has_workload The packed function of `HasWorkload`.
* \param f_commit_workload The packed function of `CommitWorkload`.
* \param f_commit_tuning_record The packed function of `CommitTuningRecord`.
* \param f_get_top_k The packed function of `GetTopK`.
* \param f_size The packed function of `Size`.
* \return The created database.
*/
- TVM_DLL static Database PyDatabase(PyDatabaseNode::FCommitWorkload f_commit_workload,
+ TVM_DLL static Database PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload,
+ PyDatabaseNode::FCommitWorkload f_commit_workload,
PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record,
PyDatabaseNode::FGetTopK f_get_top_k,
PyDatabaseNode::FSize f_size);
diff --git a/python/tvm/meta_schedule/database/database.py b/python/tvm/meta_schedule/database/database.py
index fd746e6..31822f0 100644
--- a/python/tvm/meta_schedule/database/database.py
+++ b/python/tvm/meta_schedule/database/database.py
@@ -147,6 +147,19 @@ class TuningRecord(Object):
class Database(Object):
"""The abstract database interface."""
+ def has_workload(self, mod: IRModule) -> bool:
+ """Check if the database has the given workload.
+ Parameters
+ ----------
+ mod : IRModule
+ The IRModule to be searched for.
+ Returns
+ -------
+ result : bool
+ Whether the database has the given workload.
+ """
+ return _ffi_api.DatabaseHasWorkload(self, mod) # type: ignore # pylint: disable=no-member
+
def commit_workload(self, mod: IRModule) -> Workload:
"""Commit a workload to the database if missing.
@@ -208,6 +221,10 @@ class PyDatabase(Database):
"""Constructor."""
@check_override(self.__class__, Database)
+ def f_has_workload(mod: IRModule) -> bool:
+ return self.has_workload(mod)
+
+ @check_override(self.__class__, Database)
def f_commit_workload(mod: IRModule) -> Workload:
return self.commit_workload(mod)
@@ -225,6 +242,7 @@ class PyDatabase(Database):
self.__init_handle_by_constructor__(
_ffi_api.DatabasePyDatabase, # type: ignore # pylint: disable=no-member
+ f_has_workload,
f_commit_workload,
f_commit_tuning_record,
f_get_top_k,
diff --git a/src/meta_schedule/database/database.cc b/src/meta_schedule/database/database.cc
index e67b3d1..fc7cc74 100644
--- a/src/meta_schedule/database/database.cc
+++ b/src/meta_schedule/database/database.cc
@@ -135,10 +135,12 @@ TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& w
/******** PyDatabase ********/
-Database Database::PyDatabase(PyDatabaseNode::FCommitWorkload f_commit_workload,
+Database Database::PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload,
+ PyDatabaseNode::FCommitWorkload f_commit_workload,
PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record,
PyDatabaseNode::FGetTopK f_get_top_k,
PyDatabaseNode::FSize f_size) {
ObjectPtr<PyDatabaseNode> n = make_object<PyDatabaseNode>();
+ n->f_has_workload = f_has_workload;
n->f_commit_workload = f_commit_workload;
n->f_commit_tuning_record = f_commit_tuning_record;
n->f_get_top_k = f_get_top_k;
@@ -166,6 +168,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.TuningRecord")
TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordAsJSON")
.set_body_method<TuningRecord>(&TuningRecordNode::AsJSON);
TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordFromJSON").set_body_typed(TuningRecord::FromJSON);
+TVM_REGISTER_GLOBAL("meta_schedule.DatabaseHasWorkload")
+ .set_body_method<Database>(&DatabaseNode::HasWorkload);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseCommitWorkload")
.set_body_method<Database>(&DatabaseNode::CommitWorkload);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseCommitTuningRecord")
diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc
index 3efb72e..2e76940 100644
--- a/src/meta_schedule/database/json_database.cc
+++ b/src/meta_schedule/database/json_database.cc
@@ -69,6 +69,10 @@ class JSONDatabaseNode : public DatabaseNode {
TVM_DECLARE_FINAL_OBJECT_INFO(JSONDatabaseNode, DatabaseNode);
public:
+ bool HasWorkload(const IRModule& mod) {
+ return workloads2idx_.find(Workload(mod, tvm::StructuralHash()(mod))) != workloads2idx_.end();
+ }
+
Workload CommitWorkload(const IRModule& mod) {
// Try to insert `mod` into `workloads_`
decltype(this->workloads2idx_)::iterator it;
diff --git a/tests/python/unittest/test_meta_schedule_database.py b/tests/python/unittest/test_meta_schedule_database.py
index 121ec2f..cb7761a 100644
--- a/tests/python/unittest/test_meta_schedule_database.py
+++ b/tests/python/unittest/test_meta_schedule_database.py
@@ -22,7 +22,6 @@ import tempfile
from typing import Callable
import pytest
-
import tvm
from tvm import tir
from tvm.ir.module import IRModule
@@ -132,6 +131,25 @@ def test_meta_schedule_database_create():
assert osp.exists(database.path_tuning_record)
+def test_meta_schedule_database_has_workload():
+ mod: IRModule = Matmul
+ missing_mod: IRModule = MatmulRelu
+ with tempfile.TemporaryDirectory() as tmpdir:
+ database = _create_tmp_database(tmpdir)
+ workload = database.commit_workload(mod)
+ record = TuningRecord(
+ _create_schedule(mod, _schedule_matmul).trace,
+ [1.5, 2.5, 1.8],
+ workload,
+ tvm.target.Target("llvm"),
+ ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object
+ )
+ database.commit_tuning_record(record)
+ assert len(database) == 1
+ assert database.has_workload(mod)
+ assert not database.has_workload(missing_mod)
+
+
def test_meta_schedule_database_add_entry():
mod: IRModule = Matmul
with tempfile.TemporaryDirectory() as tmpdir:
diff --git a/tests/python/unittest/test_meta_schedule_task_scheduler.py b/tests/python/unittest/test_meta_schedule_task_scheduler.py
index edff355..7eb61ad 100644
--- a/tests/python/unittest/test_meta_schedule_task_scheduler.py
+++ b/tests/python/unittest/test_meta_schedule_task_scheduler.py
@@ -140,6 +140,9 @@ class DummyDatabase(PyDatabase):
self.records = []
self.workload_reg = []
+ def has_workload(self, mod: IRModule) -> bool:
+ return False
+
def commit_tuning_record(self, record: TuningRecord) -> None:
self.records.append(record)