This is an automated email from the ASF dual-hosted git repository.

zhaowu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b08e8e4  [MetaSchedule] Add the missing HasWorkload interface to the Database (#9756)
b08e8e4 is described below

commit b08e8e49b5ea0e406c3ca41610a131dd271f0d41
Author: Junru Shao <[email protected]>
AuthorDate: Thu Dec 16 22:16:31 2021 -0800

    [MetaSchedule] Add the missing HasWorkload interface to the Database (#9756)
---
 include/tvm/meta_schedule/database.h               | 25 ++++++++++++++++++++--
 python/tvm/meta_schedule/database/database.py      | 18 ++++++++++++++++
 src/meta_schedule/database/database.cc             |  6 +++++-
 src/meta_schedule/database/json_database.cc        |  4 ++++
 .../python/unittest/test_meta_schedule_database.py | 20 ++++++++++++++++-
 .../unittest/test_meta_schedule_task_scheduler.py  |  3 +++
 6 files changed, 72 insertions(+), 4 deletions(-)

diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h
index 60c6898..f07d8e1 100644
--- a/include/tvm/meta_schedule/database.h
+++ b/include/tvm/meta_schedule/database.h
@@ -156,6 +156,12 @@ class DatabaseNode : public runtime::Object {
   /*! \brief Default destructor */
   virtual ~DatabaseNode() = default;
   /*!
+   * \brief Check if the database has the given workload.
+   * \param mod The IRModule to be searched for.
+   * \return Whether the database has the given workload.
+   */
+  virtual bool HasWorkload(const IRModule& mod) = 0;
+  /*!
    * \brief Look up or add workload to the database if missing.
    * \param mod The IRModule to be searched for or added.
    * \return The workload corresponding to the given IRModule.
@@ -187,6 +193,12 @@ class DatabaseNode : public runtime::Object {
 class PyDatabaseNode : public DatabaseNode {
  public:
   /*!
+   * \brief The function type of `HasWorkload` method.
+   * \param mod The IRModule to be searched for.
+   * \return Whether the database has the given workload.
+   */
+  using FHasWorkload = runtime::TypedPackedFunc<bool(const IRModule&)>;
+  /*!
    * \brief The function type of `CommitWorkload` method.
    * \param mod The IRModule to be searched for or added.
    * \return The workload corresponding to the given IRModule.
@@ -210,6 +222,8 @@ class PyDatabaseNode : public DatabaseNode {
    */
   using FSize = runtime::TypedPackedFunc<int64_t()>;
 
+  /*! \brief The packed function to the `HasWorkload` function. */
+  FHasWorkload f_has_workload;
   /*! \brief The packed function to the `CommitWorkload` function. */
   FCommitWorkload f_commit_workload;
   /*! \brief The packed function to the `CommitTuningRecord` function. */
@@ -223,13 +237,18 @@ class PyDatabaseNode : public DatabaseNode {
     // PackedFuncs are all not visited, because the reflection system doesn't 
take care of them,
     // so it cannot be accessible on the python side. If there is such need 
from the future,
     // we can then add corresponding accessor methods to help access on python.
-    //
+    // `f_has_workload` is not visited
     // `f_commit_workload` is not visited
     // `f_commit_tuning_record` is not visited
     // `f_get_top_k` is not visited
     // `f_size` is not visited
   }
 
+  bool HasWorkload(const IRModule& mod) final {
+    ICHECK(f_has_workload != nullptr) << "PyDatabase's HasWorkload method not 
implemented!";
+    return f_has_workload(mod);
+  }
+
   Workload CommitWorkload(const IRModule& mod) final {
     ICHECK(f_commit_workload != nullptr) << "PyDatabase's CommitWorkload 
method not implemented!";
     return f_commit_workload(mod);
@@ -271,13 +290,15 @@ class Database : public runtime::ObjectRef {
                                        bool allow_missing);
   /*!
    * \brief Create a database with customized methods on the python-side.
+   * \param f_has_workload The packed function of `HasWorkload`.
    * \param f_commit_workload The packed function of `CommitWorkload`.
    * \param f_commit_tuning_record The packed function of `CommitTuningRecord`.
    * \param f_get_top_k The packed function of `GetTopK`.
    * \param f_size The packed function of `Size`.
    * \return The created database.
    */
-  TVM_DLL static Database PyDatabase(PyDatabaseNode::FCommitWorkload 
f_commit_workload,
+  TVM_DLL static Database PyDatabase(PyDatabaseNode::FHasWorkload 
f_has_workload,
+                                     PyDatabaseNode::FCommitWorkload 
f_commit_workload,
                                      PyDatabaseNode::FCommitTuningRecord 
f_commit_tuning_record,
                                      PyDatabaseNode::FGetTopK f_get_top_k,
                                      PyDatabaseNode::FSize f_size);
diff --git a/python/tvm/meta_schedule/database/database.py b/python/tvm/meta_schedule/database/database.py
index fd746e6..31822f0 100644
--- a/python/tvm/meta_schedule/database/database.py
+++ b/python/tvm/meta_schedule/database/database.py
@@ -147,6 +147,19 @@ class TuningRecord(Object):
 class Database(Object):
     """The abstract database interface."""
 
+    def has_workload(self, mod: IRModule) -> bool:
+        """Check if the database has the given workload.
+        Parameters
+        ----------
+        mod : IRModule
+            The IRModule to be searched for.
+        Returns
+        -------
+        result : bool
+            Whether the database has the given workload.
+        """
+        return _ffi_api.DatabaseHasWorkload(self, mod)  # type: ignore # 
pylint: disable=no-member
+
     def commit_workload(self, mod: IRModule) -> Workload:
         """Commit a workload to the database if missing.
 
@@ -208,6 +221,10 @@ class PyDatabase(Database):
         """Constructor."""
 
         @check_override(self.__class__, Database)
+        def f_has_workload(mod: IRModule) -> bool:
+            return self.has_workload(mod)
+
+        @check_override(self.__class__, Database)
         def f_commit_workload(mod: IRModule) -> Workload:
             return self.commit_workload(mod)
 
@@ -225,6 +242,7 @@ class PyDatabase(Database):
 
         self.__init_handle_by_constructor__(
             _ffi_api.DatabasePyDatabase,  # type: ignore  # pylint: 
disable=no-member
+            f_has_workload,
             f_commit_workload,
             f_commit_tuning_record,
             f_get_top_k,
diff --git a/src/meta_schedule/database/database.cc b/src/meta_schedule/database/database.cc
index e67b3d1..fc7cc74 100644
--- a/src/meta_schedule/database/database.cc
+++ b/src/meta_schedule/database/database.cc
@@ -135,10 +135,12 @@ TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& w
 
 /******** PyDatabase ********/
 
-Database Database::PyDatabase(PyDatabaseNode::FCommitWorkload 
f_commit_workload,
+Database Database::PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload,
+                              PyDatabaseNode::FCommitWorkload 
f_commit_workload,
                               PyDatabaseNode::FCommitTuningRecord 
f_commit_tuning_record,
                               PyDatabaseNode::FGetTopK f_get_top_k, 
PyDatabaseNode::FSize f_size) {
   ObjectPtr<PyDatabaseNode> n = make_object<PyDatabaseNode>();
+  n->f_has_workload = f_has_workload;
   n->f_commit_workload = f_commit_workload;
   n->f_commit_tuning_record = f_commit_tuning_record;
   n->f_get_top_k = f_get_top_k;
@@ -166,6 +168,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.TuningRecord")
 TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordAsJSON")
     .set_body_method<TuningRecord>(&TuningRecordNode::AsJSON);
 
TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordFromJSON").set_body_typed(TuningRecord::FromJSON);
+TVM_REGISTER_GLOBAL("meta_schedule.DatabaseHasWorkload")
+    .set_body_method<Database>(&DatabaseNode::HasWorkload);
 TVM_REGISTER_GLOBAL("meta_schedule.DatabaseCommitWorkload")
     .set_body_method<Database>(&DatabaseNode::CommitWorkload);
 TVM_REGISTER_GLOBAL("meta_schedule.DatabaseCommitTuningRecord")
diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc
index 3efb72e..2e76940 100644
--- a/src/meta_schedule/database/json_database.cc
+++ b/src/meta_schedule/database/json_database.cc
@@ -69,6 +69,10 @@ class JSONDatabaseNode : public DatabaseNode {
   TVM_DECLARE_FINAL_OBJECT_INFO(JSONDatabaseNode, DatabaseNode);
 
  public:
+  bool HasWorkload(const IRModule& mod) {
+    return workloads2idx_.find(Workload(mod, tvm::StructuralHash()(mod))) != 
workloads2idx_.end();
+  }
+
   Workload CommitWorkload(const IRModule& mod) {
     // Try to insert `mod` into `workloads_`
     decltype(this->workloads2idx_)::iterator it;
diff --git a/tests/python/unittest/test_meta_schedule_database.py b/tests/python/unittest/test_meta_schedule_database.py
index 121ec2f..cb7761a 100644
--- a/tests/python/unittest/test_meta_schedule_database.py
+++ b/tests/python/unittest/test_meta_schedule_database.py
@@ -22,7 +22,6 @@ import tempfile
 from typing import Callable
 
 import pytest
-
 import tvm
 from tvm import tir
 from tvm.ir.module import IRModule
@@ -132,6 +131,25 @@ def test_meta_schedule_database_create():
         assert osp.exists(database.path_tuning_record)
 
 
+def test_meta_schedule_database_has_workload():
+    mod: IRModule = Matmul
+    missing_mod: IRModule = MatmulRelu
+    with tempfile.TemporaryDirectory() as tmpdir:
+        database = _create_tmp_database(tmpdir)
+        workload = database.commit_workload(mod)
+        record = TuningRecord(
+            _create_schedule(mod, _schedule_matmul).trace,
+            [1.5, 2.5, 1.8],
+            workload,
+            tvm.target.Target("llvm"),
+            ArgInfo.from_prim_func(func=mod["main"]),  # pylint: 
disable=unsubscriptable-object
+        )
+        database.commit_tuning_record(record)
+        assert len(database) == 1
+        assert database.has_workload(mod)
+        assert not database.has_workload(missing_mod)
+
+
 def test_meta_schedule_database_add_entry():
     mod: IRModule = Matmul
     with tempfile.TemporaryDirectory() as tmpdir:
diff --git a/tests/python/unittest/test_meta_schedule_task_scheduler.py b/tests/python/unittest/test_meta_schedule_task_scheduler.py
index edff355..7eb61ad 100644
--- a/tests/python/unittest/test_meta_schedule_task_scheduler.py
+++ b/tests/python/unittest/test_meta_schedule_task_scheduler.py
@@ -140,6 +140,9 @@ class DummyDatabase(PyDatabase):
         self.records = []
         self.workload_reg = []
 
+    def has_workload(self, mod: IRModule) -> bool:
+        return False
+
     def commit_tuning_record(self, record: TuningRecord) -> None:
         self.records.append(record)
 

Reply via email to