This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 19d970cab1 [MetaSchedule] Fix anchor-block flow with empty design space generator (#14047)
19d970cab1 is described below
commit 19d970cab1d48648d8214883e0a5385d34715fdc
Author: Icemist <[email protected]>
AuthorDate: Wed Feb 22 12:19:05 2023 +0300
[MetaSchedule] Fix anchor-block flow with empty design space generator (#14047)
---
include/tvm/meta_schedule/database.h | 7 ++++++-
python/tvm/meta_schedule/database/database.py | 2 +-
src/meta_schedule/database/database.cc | 15 +++++++++++++++
src/meta_schedule/database/json_database.cc | 12 +++---------
src/meta_schedule/database/memory_database.cc | 13 +++----------
src/meta_schedule/trace_apply.cc | 4 +++-
src/tir/schedule/utils.h | 8 ++++++++
.../unittest/test_meta_schedule_relay_integration.py | 14 +++++++++++---
8 files changed, 50 insertions(+), 25 deletions(-)
diff --git a/include/tvm/meta_schedule/database.h
b/include/tvm/meta_schedule/database.h
index a1dd4a412e..bea9cc3e37 100644
--- a/include/tvm/meta_schedule/database.h
+++ b/include/tvm/meta_schedule/database.h
@@ -144,6 +144,11 @@ class TuningRecordNode : public runtime::Object {
* argument information.
*/
ObjectRef AsJSON() const;
+ /*!
+ * \brief Check if this tuning record has valid trace instructions and successful run results.
+ *
+ * A record is valid when GetNumValidInstructions(trace->insts, remove_postproc = true) is
+ * non-zero AND at least one entry of run_secs is defined and differs from
+ * SortTuningRecordByMeanRunSecs::kMaxMeanTime, the stub used for undefined measurement times.
+ * \return True if the record is valid, false otherwise.
+ */
+ bool IsValid() const;
};
/*!
@@ -210,7 +215,7 @@ class DatabaseNode : public runtime::Object {
*/
virtual void CommitTuningRecord(const TuningRecord& record) = 0;
/*!
- * \brief Get the top K tuning records of given workload from the database.
+ * \brief Get the top K valid tuning records of the given workload from the database.
* \param workload The workload to be searched for.
* \param top_k The number of top records to be returned.
* \return An array of top K tuning records for the given workload.
diff --git a/python/tvm/meta_schedule/database/database.py
b/python/tvm/meta_schedule/database/database.py
index b95cb1ddd7..6621db9633 100644
--- a/python/tvm/meta_schedule/database/database.py
+++ b/python/tvm/meta_schedule/database/database.py
@@ -205,7 +205,7 @@ class Database(Object):
_ffi_api.DatabaseCommitTuningRecord(self, record) # type: ignore #
pylint: disable=no-member
def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
- """Get the top K tuning records of given workload from the database.
+ """Get the top K valid tuning records of given workload from the
database.
Parameters
----------
diff --git a/src/meta_schedule/database/database.cc
b/src/meta_schedule/database/database.cc
index da1d1db8f1..649429f9bc 100644
--- a/src/meta_schedule/database/database.cc
+++ b/src/meta_schedule/database/database.cc
@@ -113,6 +113,21 @@ ObjectRef TuningRecordNode::AsJSON() const {
json_args_info};
}
+// A record is usable only if its trace still has instructions after postprocessing ones are
+// excluded, and at least one of its measured run times is a real (non-stub) value.
+bool TuningRecordNode::IsValid() const {
+ // Reject traces with zero valid instructions when postprocessing instructions are excluded.
+ if (!GetNumValidInstructions(trace->insts, /*remove_postproc*/ true)) {
+ return false;
+ }
+ // Accept as soon as one run time is defined and is not the stub value. Records with undefined
+ // run_secs, an empty list, or only undefined/stub entries fall through to `return false`.
+ if (run_secs.defined()) {
+ for (const auto& run_sec : run_secs.value()) {
+ // kMaxMeanTime(1e10) is used as a stub for undefined measurement times.
+ if (run_sec.defined() && run_sec->value !=
SortTuningRecordByMeanRunSecs::kMaxMeanTime) {
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload&
workload) {
tir::Trace trace{nullptr};
Optional<Array<FloatImm>> run_secs{nullptr};
diff --git a/src/meta_schedule/database/json_database.cc
b/src/meta_schedule/database/json_database.cc
index 0e51e262df..10ff89a7ce 100644
--- a/src/meta_schedule/database/json_database.cc
+++ b/src/meta_schedule/database/json_database.cc
@@ -128,13 +128,7 @@ class JSONDatabaseNode : public DatabaseNode {
results.reserve(top_k);
for (const TuningRecord& record : this->tuning_records_) {
auto run_secs = record->run_secs;
- if (!run_secs.defined() || run_secs.value().empty() ||
- std::all_of(run_secs.value().begin(), run_secs.value().end(),
- // kMaxMeanTime(1e10) is used as a stub for undefined
measurement times.
- [](tvm::FloatImm v) {
- return v.defined() &&
- v->value ==
SortTuningRecordByMeanRunSecs::kMaxMeanTime;
- })) {
+ if (!record->IsValid()) {
continue;
}
if (record->workload.same_as(workload) ||
@@ -146,8 +140,8 @@ class JSONDatabaseNode : public DatabaseNode {
}
}
if (results.size() < static_cast<size_t>(top_k)) {
- LOG(WARNING) << "The size of the GetTopK result is smaller than
requested. There are not "
- "enough valid records in the database for this
workload.";
+ LOG(WARNING) << "Returned tuning records less than requested(" <<
results.size() << " of "
+ << top_k << " asked).";
}
return results;
}
diff --git a/src/meta_schedule/database/memory_database.cc
b/src/meta_schedule/database/memory_database.cc
index 533a86acac..b003606c9c 100644
--- a/src/meta_schedule/database/memory_database.cc
+++ b/src/meta_schedule/database/memory_database.cc
@@ -68,14 +68,7 @@ class MemoryDatabaseNode : public DatabaseNode {
std::vector<TuningRecord> results;
results.reserve(records.size());
for (const TuningRecord& record : records) {
- auto run_secs = record->run_secs;
- if (!run_secs.defined() || run_secs.value().empty() ||
- std::all_of(run_secs.value().begin(), run_secs.value().end(),
- // kMaxMeanTime(1e10) is used as a stub for undefined
measurement times.
- [](tvm::FloatImm v) {
- return v.defined() &&
- v->value ==
SortTuningRecordByMeanRunSecs::kMaxMeanTime;
- })) {
+ if (!record->IsValid()) {
continue;
}
if (record->workload.same_as(workload) ||
@@ -88,8 +81,8 @@ class MemoryDatabaseNode : public DatabaseNode {
return {results.begin(), results.begin() + top_k};
} else {
if (results.size() < static_cast<size_t>(top_k)) {
- LOG(WARNING) << "The size of the GetTopK result is smaller than
requested. There are not "
- "enough valid records in the database for this
workload.";
+ LOG(WARNING) << "Returned tuning records less than requested(" <<
results.size() << " of "
+ << top_k << " asked).";
}
return results;
}
diff --git a/src/meta_schedule/trace_apply.cc b/src/meta_schedule/trace_apply.cc
index 9213d414e1..e60fdf5b9d 100644
--- a/src/meta_schedule/trace_apply.cc
+++ b/src/meta_schedule/trace_apply.cc
@@ -75,7 +75,9 @@ void InlinePostBlocks(Schedule sch, Trace anchor_trace,
Target target) {
// Spatial blocks which are not referenced in the anchor trace will be
inlined here.
auto block_sref = sch->GetSRef(block);
if (IsSpatial(block_sref) && !get_block_names.count(name)) {
- if (IsOutputBlock(sch->state(), block_sref, GetScopeRoot(sch->state(),
block_sref, false))) {
+ StmtSRef scopeRoot =
+ (name != "root") ? GetScopeRoot(sch->state(), block_sref, false) :
block_sref;
+ if (IsOutputBlock(sch->state(), block_sref, scopeRoot)) {
last_block_idx = inline_todos.size();
}
inline_todos.push_back(name);
diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h
index a6aced4632..df92c7f807 100644
--- a/src/tir/schedule/utils.h
+++ b/src/tir/schedule/utils.h
@@ -490,6 +490,14 @@ Array<ObjectRef> TranslateInputRVs(const Array<ObjectRef>&
inputs,
void TranslateAddOutputRVs(const Array<ObjectRef>& old_outputs, const
Array<ObjectRef>& new_outputs,
std::unordered_map<const Object*, const Object*>*
rv_map);
+/*!
+ * \brief Counts the valid instructions of a trace.
+ * \note Which instructions count as "valid" is decided by the implementation, which is not
+ *       shown here; TODO(review): confirm and document the exact criterion.
+ * \param insts The instructions representing a trace.
+ * \param remove_postproc Whether postprocessing instructions are excluded from the count.
+ * \return The number of counted instructions (0 means the trace has nothing to replay).
+ */
+int GetNumValidInstructions(const Array<Instruction>& insts, bool
remove_postproc);
+
} // namespace tir
} // namespace tvm
diff --git a/tests/python/unittest/test_meta_schedule_relay_integration.py
b/tests/python/unittest/test_meta_schedule_relay_integration.py
index 90be1ec0a1..f1d74348db 100644
--- a/tests/python/unittest/test_meta_schedule_relay_integration.py
+++ b/tests/python/unittest/test_meta_schedule_relay_integration.py
@@ -724,7 +724,7 @@ def test_module_equality_ignore_ndarray():
np.testing.assert_allclose(ref, out, rtol=1e-4, atol=1e-4)
-def _test_anchor_tuning(target):
+def _test_anchor_tuning(target, space):
data_shape = (128, 128)
weight_shape1 = (128, 128)
weight_shape2 = (128, 128)
@@ -756,6 +756,7 @@ def _test_anchor_tuning(target):
target=target,
params=params,
work_dir=work_dir,
+ space=space,
max_trials_global=4,
strategy="replay-trace",
module_equality=module_equality,
@@ -779,8 +780,15 @@ def _test_anchor_tuning(target):
np.testing.assert_allclose(ref, out, atol=1e-3)
-def test_anchor_tuning_cpu():
- _test_anchor_tuning("llvm --num-cores=4")
+@pytest.mark.parametrize(
+    "space",
+    [
+        # Default design space generator.
+        ms.space_generator.PostOrderApply(),
+        # Empty design space generator: no schedule rules, postprocessors, or mutators.
+        ms.space_generator.PostOrderApply(sch_rules=[], postprocs=[], mutator_probs={}),
+    ],
+)
+def test_anchor_tuning_cpu(space):
+    """Run anchor-block tuning on CPU with the parametrized design space generator."""
+    _test_anchor_tuning("llvm --num-cores=4", space)
def test_anchor_tuning_cpu_link_params():