This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 19d970cab1 [MetaSchedule] Fix anchor-block flow with empty design space generator (#14047)
19d970cab1 is described below

commit 19d970cab1d48648d8214883e0a5385d34715fdc
Author: Icemist <[email protected]>
AuthorDate: Wed Feb 22 12:19:05 2023 +0300

    [MetaSchedule] Fix anchor-block flow with empty design space generator (#14047)
---
 include/tvm/meta_schedule/database.h                      |  7 ++++++-
 python/tvm/meta_schedule/database/database.py             |  2 +-
 src/meta_schedule/database/database.cc                    | 15 +++++++++++++++
 src/meta_schedule/database/json_database.cc               | 12 +++---------
 src/meta_schedule/database/memory_database.cc             | 13 +++----------
 src/meta_schedule/trace_apply.cc                          |  4 +++-
 src/tir/schedule/utils.h                                  |  8 ++++++++
 .../unittest/test_meta_schedule_relay_integration.py      | 14 +++++++++++---
 8 files changed, 50 insertions(+), 25 deletions(-)

diff --git a/include/tvm/meta_schedule/database.h 
b/include/tvm/meta_schedule/database.h
index a1dd4a412e..bea9cc3e37 100644
--- a/include/tvm/meta_schedule/database.h
+++ b/include/tvm/meta_schedule/database.h
@@ -144,6 +144,11 @@ class TuningRecordNode : public runtime::Object {
    * argument information.
    */
   ObjectRef AsJSON() const;
+  /*!
+   * \brief Check if this tuning record has valid trace instructions and 
successful run results.
+   * \return The check result.
+   */
+  bool IsValid() const;
 };
 
 /*!
@@ -210,7 +215,7 @@ class DatabaseNode : public runtime::Object {
    */
   virtual void CommitTuningRecord(const TuningRecord& record) = 0;
   /*!
-   * \brief Get the top K tuning records of given workload from the database.
+   * \brief Get the top K valid tuning records of given workload from the 
database.
    * \param workload The workload to be searched for.
    * \param top_k The number of top records to be returned.
    * \return An array of top K tuning records for the given workload.
diff --git a/python/tvm/meta_schedule/database/database.py 
b/python/tvm/meta_schedule/database/database.py
index b95cb1ddd7..6621db9633 100644
--- a/python/tvm/meta_schedule/database/database.py
+++ b/python/tvm/meta_schedule/database/database.py
@@ -205,7 +205,7 @@ class Database(Object):
         _ffi_api.DatabaseCommitTuningRecord(self, record)  # type: ignore # 
pylint: disable=no-member
 
     def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
-        """Get the top K tuning records of given workload from the database.
+        """Get the top K valid tuning records of given workload from the 
database.
 
         Parameters
         ----------
diff --git a/src/meta_schedule/database/database.cc 
b/src/meta_schedule/database/database.cc
index da1d1db8f1..649429f9bc 100644
--- a/src/meta_schedule/database/database.cc
+++ b/src/meta_schedule/database/database.cc
@@ -113,6 +113,21 @@ ObjectRef TuningRecordNode::AsJSON() const {
                           json_args_info};
 }
 
+bool TuningRecordNode::IsValid() const {
+  if (!GetNumValidInstructions(trace->insts, /*remove_postproc*/ true)) {
+    return false;
+  }
+  if (run_secs.defined()) {
+    for (const auto& run_sec : run_secs.value()) {
+      // kMaxMeanTime(1e10) is used as a stub for undefined measurement times.
+      if (run_sec.defined() && run_sec->value != 
SortTuningRecordByMeanRunSecs::kMaxMeanTime) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
 TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& 
workload) {
   tir::Trace trace{nullptr};
   Optional<Array<FloatImm>> run_secs{nullptr};
diff --git a/src/meta_schedule/database/json_database.cc 
b/src/meta_schedule/database/json_database.cc
index 0e51e262df..10ff89a7ce 100644
--- a/src/meta_schedule/database/json_database.cc
+++ b/src/meta_schedule/database/json_database.cc
@@ -128,13 +128,7 @@ class JSONDatabaseNode : public DatabaseNode {
     results.reserve(top_k);
     for (const TuningRecord& record : this->tuning_records_) {
       auto run_secs = record->run_secs;
-      if (!run_secs.defined() || run_secs.value().empty() ||
-          std::all_of(run_secs.value().begin(), run_secs.value().end(),
-                      // kMaxMeanTime(1e10) is used as a stub for undefined 
measurement times.
-                      [](tvm::FloatImm v) {
-                        return v.defined() &&
-                               v->value == 
SortTuningRecordByMeanRunSecs::kMaxMeanTime;
-                      })) {
+      if (!record->IsValid()) {
         continue;
       }
       if (record->workload.same_as(workload) ||
@@ -146,8 +140,8 @@ class JSONDatabaseNode : public DatabaseNode {
       }
     }
     if (results.size() < static_cast<size_t>(top_k)) {
-      LOG(WARNING) << "The size of the GetTopK result is smaller than 
requested. There are not "
-                      "enough valid records in the database for this 
workload.";
+      LOG(WARNING) << "Returned tuning records less than requested(" << 
results.size() << " of "
+                   << top_k << " asked).";
     }
     return results;
   }
diff --git a/src/meta_schedule/database/memory_database.cc 
b/src/meta_schedule/database/memory_database.cc
index 533a86acac..b003606c9c 100644
--- a/src/meta_schedule/database/memory_database.cc
+++ b/src/meta_schedule/database/memory_database.cc
@@ -68,14 +68,7 @@ class MemoryDatabaseNode : public DatabaseNode {
     std::vector<TuningRecord> results;
     results.reserve(records.size());
     for (const TuningRecord& record : records) {
-      auto run_secs = record->run_secs;
-      if (!run_secs.defined() || run_secs.value().empty() ||
-          std::all_of(run_secs.value().begin(), run_secs.value().end(),
-                      // kMaxMeanTime(1e10) is used as a stub for undefined 
measurement times.
-                      [](tvm::FloatImm v) {
-                        return v.defined() &&
-                               v->value == 
SortTuningRecordByMeanRunSecs::kMaxMeanTime;
-                      })) {
+      if (!record->IsValid()) {
         continue;
       }
       if (record->workload.same_as(workload) ||
@@ -88,8 +81,8 @@ class MemoryDatabaseNode : public DatabaseNode {
       return {results.begin(), results.begin() + top_k};
     } else {
       if (results.size() < static_cast<size_t>(top_k)) {
-        LOG(WARNING) << "The size of the GetTopK result is smaller than 
requested. There are not "
-                        "enough valid records in the database for this 
workload.";
+        LOG(WARNING) << "Returned tuning records less than requested(" << 
results.size() << " of "
+                     << top_k << " asked).";
       }
       return results;
     }
diff --git a/src/meta_schedule/trace_apply.cc b/src/meta_schedule/trace_apply.cc
index 9213d414e1..e60fdf5b9d 100644
--- a/src/meta_schedule/trace_apply.cc
+++ b/src/meta_schedule/trace_apply.cc
@@ -75,7 +75,9 @@ void InlinePostBlocks(Schedule sch, Trace anchor_trace, 
Target target) {
     // Spatial blocks which are not referenced in the anchor trace will be 
inlined here.
     auto block_sref = sch->GetSRef(block);
     if (IsSpatial(block_sref) && !get_block_names.count(name)) {
-      if (IsOutputBlock(sch->state(), block_sref, GetScopeRoot(sch->state(), 
block_sref, false))) {
+      StmtSRef scopeRoot =
+          (name != "root") ? GetScopeRoot(sch->state(), block_sref, false) : 
block_sref;
+      if (IsOutputBlock(sch->state(), block_sref, scopeRoot)) {
         last_block_idx = inline_todos.size();
       }
       inline_todos.push_back(name);
diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h
index a6aced4632..df92c7f807 100644
--- a/src/tir/schedule/utils.h
+++ b/src/tir/schedule/utils.h
@@ -490,6 +490,14 @@ Array<ObjectRef> TranslateInputRVs(const Array<ObjectRef>& 
inputs,
 void TranslateAddOutputRVs(const Array<ObjectRef>& old_outputs, const 
Array<ObjectRef>& new_outputs,
                            std::unordered_map<const Object*, const Object*>* 
rv_map);
 
+/*!
+ * \brief Counts the number of trace instructions.
+ * \param insts The instructions representing a trace.
+ * \param remove_postproc If postprocessing instructions are removed.
+ * \return Number of instructions.
+ */
+int GetNumValidInstructions(const Array<Instruction>& insts, bool 
remove_postproc);
+
 }  // namespace tir
 }  // namespace tvm
 
diff --git a/tests/python/unittest/test_meta_schedule_relay_integration.py 
b/tests/python/unittest/test_meta_schedule_relay_integration.py
index 90be1ec0a1..f1d74348db 100644
--- a/tests/python/unittest/test_meta_schedule_relay_integration.py
+++ b/tests/python/unittest/test_meta_schedule_relay_integration.py
@@ -724,7 +724,7 @@ def test_module_equality_ignore_ndarray():
     np.testing.assert_allclose(ref, out, rtol=1e-4, atol=1e-4)
 
 
-def _test_anchor_tuning(target):
+def _test_anchor_tuning(target, space):
     data_shape = (128, 128)
     weight_shape1 = (128, 128)
     weight_shape2 = (128, 128)
@@ -756,6 +756,7 @@ def _test_anchor_tuning(target):
             target=target,
             params=params,
             work_dir=work_dir,
+            space=space,
             max_trials_global=4,
             strategy="replay-trace",
             module_equality=module_equality,
@@ -779,8 +780,15 @@ def _test_anchor_tuning(target):
     np.testing.assert_allclose(ref, out, atol=1e-3)
 
 
-def test_anchor_tuning_cpu():
-    _test_anchor_tuning("llvm --num-cores=4")
+@pytest.mark.parametrize(
+    "space",
+    [
+        ms.space_generator.PostOrderApply(),
+        ms.space_generator.PostOrderApply(sch_rules=[], postprocs=[], 
mutator_probs={}),
+    ],
+)
+def test_anchor_tuning_cpu(space):
+    _test_anchor_tuning("llvm --num-cores=4", space)
 
 
 def test_anchor_tuning_cpu_link_params():

Reply via email to