This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e61ad7ab82 [MetaSchedule] Add Profiler Support For Tuning Efficiency
Optimization (#11486)
e61ad7ab82 is described below
commit e61ad7ab826a73347280468e2da47f215f76e05d
Author: Xiyou Zhou <[email protected]>
AuthorDate: Mon Jun 13 08:41:53 2022 -0700
[MetaSchedule] Add Profiler Support For Tuning Efficiency Optimization
(#11486)
Co-authored-by: Junru Shao <[email protected]>
---
include/tvm/meta_schedule/profiler.h | 103 ++++++++++++++++
python/tvm/meta_schedule/__init__.py | 1 +
python/tvm/meta_schedule/profiler.py | 76 ++++++++++++
.../testing/tune_relay_meta_schedule.py | 29 +++--
.../measure_callback/add_to_database.cc | 1 +
.../measure_callback/echo_statistics.cc | 1 +
.../measure_callback/measure_callback.cc | 1 +
.../measure_callback/remove_build_artifact.cc | 1 +
.../measure_callback/update_cost_model.cc | 6 +-
src/meta_schedule/profiler.cc | 134 +++++++++++++++++++++
.../search_strategy/evolutionary_search.cc | 74 ++++++------
src/meta_schedule/tune_context.cc | 11 +-
src/meta_schedule/utils.h | 1 +
.../python/unittest/test_meta_schedule_profiler.py | 46 +++++++
.../unittest/test_meta_schedule_search_strategy.py | 4 +-
15 files changed, 434 insertions(+), 55 deletions(-)
diff --git a/include/tvm/meta_schedule/profiler.h b/include/tvm/meta_schedule/profiler.h
new file mode 100644
index 0000000000..0f6572cca9
--- /dev/null
+++ b/include/tvm/meta_schedule/profiler.h
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_META_SCHEDULE_PROFILER_H_
+#define TVM_META_SCHEDULE_PROFILER_H_
+
+#include <tvm/ir/module.h>
+#include <tvm/node/reflection.h>
+#include <tvm/runtime/container/array.h>
+#include <tvm/runtime/container/optional.h>
+#include <tvm/runtime/container/string.h>
+#include <tvm/runtime/object.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/target/target.h>
+
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace meta_schedule {
+
+/*!
+ * \brief Guard object returned by Profiler::TimedScope. On destruction it invokes its deferred
+ *  callback (if any), which, as produced by Profiler::TimedScope, records the time elapsed in
+ *  the timed scope.
+ * \note The guard is move-only: a copy would invoke the same callback twice and double-count
+ *  the scope.
+ */
+class ScopedTimer {
+ public:
+  ~ScopedTimer() {
+    if (deferred_ != nullptr) {
+      deferred_();
+    }
+  }
+
+  ScopedTimer(const ScopedTimer&) = delete;
+  ScopedTimer& operator=(const ScopedTimer&) = delete;
+  /*! \brief Transfer the pending callback; the moved-from guard no longer fires it. */
+  ScopedTimer(ScopedTimer&& other) noexcept : deferred_(std::move(other.deferred_)) {
+    // Disarm the source explicitly so only one destructor runs the callback.
+    other.deferred_ = runtime::TypedPackedFunc<void()>();
+  }
+  ScopedTimer& operator=(ScopedTimer&&) = delete;
+
+ private:
+  friend class Profiler;
+
+  explicit ScopedTimer(runtime::TypedPackedFunc<void()> deferred)
+      : deferred_(std::move(deferred)) {}
+  /*! \brief Callback run on destruction; null when no profiler was active. */
+  runtime::TypedPackedFunc<void()> deferred_;
+};
+
+/*!
+ * \brief A generic profiler that accumulates the elapsed time of named scopes.
+ * \sa Profiler
+ */
+class ProfilerNode : public runtime::Object {
+ public:
+  /*!
+   * \brief Accumulated elapsed time, in seconds, of each profiled scope, keyed by scope name.
+   *  The scope opened on entering the profiler is recorded under "Total".
+   */
+  std::unordered_map<std::string, double> stats_sec;
+  /*!
+   * \brief Deferred callback that records the "Total" entry. It is set when the profiler's scope
+   *  is entered, and invoked and then reset to null when the scope is exited.
+   */
+  runtime::PackedFunc total_timer;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    // `stats_sec` is not visited.
+    // `total_timer` is not visited.
+  }
+
+  static constexpr const char* _type_key = "meta_schedule.Profiler";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ProfilerNode, runtime::Object);
+
+ public:
+  /*! \brief Get the accumulated time, in seconds, of every profiled scope. */
+  Map<String, FloatImm> Get() const;
+  /*!
+   * \brief Return a summary of the profiling results in table format, listing each scope's
+   *  time in minutes and its percentage of "Total".
+   * \note Fails if "Total" has not been recorded, i.e. before the profiler's scope is exited.
+   */
+  String Table() const;
+};
+
+/*!
+ * \brief Managed reference to ProfilerNode
+ * \sa ProfilerNode
+ */
+class Profiler : public runtime::ObjectRef {
+ public:
+  /*! \brief Construct a profiler with no recorded stats. */
+  Profiler();
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Profiler, runtime::ObjectRef, ProfilerNode);
+
+  /*!
+   * \brief Entering the scope of the context manager: makes this profiler the current one of
+   *  the calling thread and starts timing the "Total" entry.
+   */
+  void EnterWithScope();
+  /*!
+   * \brief Exiting the scope of the context manager: removes the innermost profiler of the
+   *  calling thread and records this profiler's "Total" time.
+   */
+  void ExitWithScope();
+  /*! \brief Returns the innermost active profiler of the calling thread, or NullOpt if none. */
+  static Optional<Profiler> Current();
+  /*!
+   * \brief Profile the time usage in the given scope in the given name.
+   * \param name Name for the scope.
+   * \return A scope timer for time profiling. When destroyed it adds the elapsed time to the
+   *  profiler that was current at creation; it does nothing if no profiler was active.
+   */
+  static ScopedTimer TimedScope(String name);
+};
+
+} // namespace meta_schedule
+} // namespace tvm
+
+#endif // TVM_META_SCHEDULE_PROFILER_H_
diff --git a/python/tvm/meta_schedule/__init__.py
b/python/tvm/meta_schedule/__init__.py
index 0028fbdf4f..26cf446b10 100644
--- a/python/tvm/meta_schedule/__init__.py
+++ b/python/tvm/meta_schedule/__init__.py
@@ -30,6 +30,7 @@ from . import (
search_strategy,
space_generator,
)
+from .profiler import Profiler
from .apply_history_best import ApplyHistoryBest
from .extracted_task import ExtractedTask
from .relay_integration import extract_task_from_relay
diff --git a/python/tvm/meta_schedule/profiler.py b/python/tvm/meta_schedule/profiler.py
new file mode 100644
index 0000000000..a83d0fa16e
--- /dev/null
+++ b/python/tvm/meta_schedule/profiler.py
@@ -0,0 +1,76 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""A context manager that profiles tuning time cost for different parts."""
+from __future__ import annotations
+
+import logging
+from contextlib import contextmanager
+from typing import Dict, Optional
+
+from tvm._ffi import register_object
+from tvm.runtime import Object
+
+from . import _ffi_api
+
+logger = logging.getLogger(__name__) # pylint: disable=invalid-name
+
+
+@register_object("meta_schedule.Profiler")
+class Profiler(Object):
+ """Tuning time profiler."""
+
+ def __init__(self) -> None:
+ self.__init_handle_by_constructor__(
+ _ffi_api.Profiler, # type: ignore # pylint: disable=no-member
+ )
+
+ def get(self) -> Dict[str, float]:
+ """Get the profiling results in minutes"""
+ return _ffi_api.ProfilerGet(self) # type: ignore # pylint:
disable=no-member
+
+ def table(self) -> str:
+ """Get the profiling results in a table format"""
+ return _ffi_api.ProfilerTable(self) # type: ignore # pylint:
disable=no-member
+
+ def __enter__(self) -> "Profiler":
+ """Entering the scope of the context manager"""
+ _ffi_api.ProfilerEnterWithScope(self) # type: ignore # pylint:
disable=no-member
+ return self
+
+ def __exit__(self, ptype, value, trace) -> None:
+ """Exiting the scope of the context manager"""
+ _ffi_api.ProfilerExitWithScope(self) # type: ignore # pylint:
disable=no-member
+
+ @staticmethod
+ def current() -> Optional["Profiler"]:
+ """Get the current profiler."""
+ return _ffi_api.ProfilerCurrent() # type: ignore # pylint:
disable=no-member
+
+ @staticmethod
+ def timeit(name: str):
+ """Timeit a block of code"""
+
+ @contextmanager
+ def _timeit():
+ try:
+ f = _ffi_api.ProfilerTimedScope(name) # type: ignore #
pylint: disable=no-member
+ yield
+ finally:
+ if f:
+ f()
+
+ return _timeit()
diff --git a/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py
b/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py
index bd858e0f2d..ee26b6303d 100644
--- a/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py
+++ b/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py
@@ -122,19 +122,22 @@ def main():
alloc_repeat=alloc_repeat,
max_workers=ARGS.rpc_workers,
)
- lib = ms.tune_relay(
- mod=mod,
- target=ARGS.target,
- config=ms.TuneConfig(
- strategy="evolutionary",
- num_trials_per_iter=64,
- max_trials_per_task=ARGS.num_trials,
- max_trials_global=ARGS.num_trials,
- ),
- runner=runner, # type: ignore
- work_dir=ARGS.work_dir,
- params=params,
- )
+ with ms.Profiler() as profiler:
+ lib = ms.tune_relay(
+ mod=mod,
+ target=ARGS.target,
+ config=ms.TuneConfig(
+ strategy="evolutionary",
+ num_trials_per_iter=64,
+ max_trials_per_task=ARGS.num_trials,
+ max_trials_global=ARGS.num_trials,
+ ),
+ runner=runner, # type: ignore
+ work_dir=ARGS.work_dir,
+ params=params,
+ )
+ print("Tuning Time:")
+ print(profiler.table())
graph, rt_mod, params = lib.graph_json, lib.lib, lib.params
for input_name, input_shape in input_info.items():
if input_dtype.startswith("float"):
diff --git a/src/meta_schedule/measure_callback/add_to_database.cc
b/src/meta_schedule/measure_callback/add_to_database.cc
index 27b4e55a7d..e86da3720f 100644
--- a/src/meta_schedule/measure_callback/add_to_database.cc
+++ b/src/meta_schedule/measure_callback/add_to_database.cc
@@ -30,6 +30,7 @@ class AddToDatabaseNode : public MeasureCallbackNode {
if (!task_scheduler->database.defined()) {
return;
}
+ auto _ = Profiler::TimedScope("AddToDatabase");
TuneContext task = task_scheduler->tasks[task_id];
Database database = task_scheduler->database.value();
Workload workload = database->CommitWorkload(task->mod.value());
diff --git a/src/meta_schedule/measure_callback/echo_statistics.cc
b/src/meta_schedule/measure_callback/echo_statistics.cc
index e45f98b52e..5f3dce06f0 100644
--- a/src/meta_schedule/measure_callback/echo_statistics.cc
+++ b/src/meta_schedule/measure_callback/echo_statistics.cc
@@ -82,6 +82,7 @@ class EchoStatisticsNode : public MeasureCallbackNode {
if (this->task_info.empty()) {
SetupTaskInfo(task_scheduler->tasks);
}
+ auto _ = Profiler::TimedScope("EchoStatistics");
ICHECK_EQ(measure_candidates.size(), builder_results.size());
ICHECK_EQ(measure_candidates.size(), runner_results.size());
int n = measure_candidates.size();
diff --git a/src/meta_schedule/measure_callback/measure_callback.cc
b/src/meta_schedule/measure_callback/measure_callback.cc
index c7851a6fad..e49f5216ec 100644
--- a/src/meta_schedule/measure_callback/measure_callback.cc
+++ b/src/meta_schedule/measure_callback/measure_callback.cc
@@ -27,6 +27,7 @@ void PyMeasureCallbackNode::Apply(const TaskScheduler&
task_scheduler,
const Array<BuilderResult>& builds,
//
const Array<RunnerResult>& results) {
ICHECK(f_apply != nullptr) << "PyMeasureCallback's Apply method not
implemented!";
+ auto _ = Profiler::TimedScope(this->f_as_string());
return f_apply(task_scheduler, task_id, measure_candidates, builds, results);
}
diff --git a/src/meta_schedule/measure_callback/remove_build_artifact.cc
b/src/meta_schedule/measure_callback/remove_build_artifact.cc
index 649636def1..67267dff91 100644
--- a/src/meta_schedule/measure_callback/remove_build_artifact.cc
+++ b/src/meta_schedule/measure_callback/remove_build_artifact.cc
@@ -28,6 +28,7 @@ class RemoveBuildArtifactNode : public MeasureCallbackNode {
const Array<BuilderResult>& builder_results,
const Array<RunnerResult>& runner_results) final {
static const PackedFunc* f_rm =
runtime::Registry::Get("meta_schedule.remove_build_dir");
+ auto _ = Profiler::TimedScope("RemoveBuildArtifact");
for (const BuilderResult& build_result : builder_results) {
if (Optional<String> path = build_result->artifact_path) {
(*f_rm)(path.value());
diff --git a/src/meta_schedule/measure_callback/update_cost_model.cc
b/src/meta_schedule/measure_callback/update_cost_model.cc
index 00f6f94eb7..5b6208581c 100644
--- a/src/meta_schedule/measure_callback/update_cost_model.cc
+++ b/src/meta_schedule/measure_callback/update_cost_model.cc
@@ -27,11 +27,11 @@ class UpdateCostModelNode : public MeasureCallbackNode {
const Array<MeasureCandidate>& measure_candidates,
const Array<BuilderResult>& builder_results,
const Array<RunnerResult>& runner_results) final {
+ auto _ = Profiler::TimedScope("UpdateCostModel");
TuneContext task = task_scheduler->tasks[task_id];
- ICHECK(task_scheduler->cost_model.defined()) //
+ ICHECK(task_scheduler->cost_model.defined())
<< "Cost model must be defined for the task scheduler!";
- ICHECK(task->measure_candidates.defined()) //
- << "Task's measure candidates must be present!";
+ ICHECK(task->measure_candidates.defined()) << "Task's measure candidates
must be present!";
CostModel cost_model = task_scheduler->cost_model.value();
ICHECK_EQ(measure_candidates.size(), builder_results.size());
ICHECK_EQ(runner_results.size(), builder_results.size());
diff --git a/src/meta_schedule/profiler.cc b/src/meta_schedule/profiler.cc
new file mode 100644
index 0000000000..d3f72bb705
--- /dev/null
+++ b/src/meta_schedule/profiler.cc
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <algorithm>
+
+#include "./utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+/**************** Profiler ****************/
+
+// Export the per-scope seconds as an FFI-friendly map of float64 immediates.
+Map<String, FloatImm> ProfilerNode::Get() const {
+  Map<String, FloatImm> stats;
+  for (auto it = stats_sec.begin(); it != stats_sec.end(); ++it) {
+    const double seconds = it->second;
+    stats.Set(it->first, FloatImm(DataType::Float(64), seconds));
+  }
+  return stats;
+}
+
+/*!
+ * \brief Render the recorded stats as a table with columns ID, Name, Time (min), Percentage.
+ *  The "Total" row is always first and unnumbered; the other rows follow by decreasing time.
+ */
+String ProfilerNode::Table() const {
+  CHECK(!stats_sec.empty()) << "ValueError: The stats are empty. Please run the profiler first.";
+  CHECK(stats_sec.count("Total"))
+      << "ValueError: The total time is not recorded. This method should be called only after "
+         "exiting the profiler's with scope.";
+  double total = stats_sec.at("Total");
+  struct Entry {
+    String name;
+    double minutes;
+    double percentage;
+  };
+  std::vector<Entry> table_entry;
+  table_entry.reserve(stats_sec.size());
+  for (const auto& kv : stats_sec) {
+    // Guard against a zero total (e.g. a very short run), which would otherwise yield NaN.
+    double percentage = total > 0.0 ? kv.second / total * 100.0 : 0.0;
+    table_entry.push_back(Entry{kv.first, kv.second / 60.0, percentage});
+  }
+  // Pin "Total" to the first row explicitly instead of relying on it having the largest
+  // percentage: ties (e.g. several entries with equal time) would otherwise be ordered
+  // arbitrarily by std::sort. Remaining ties are broken by name for deterministic output.
+  std::sort(table_entry.begin(), table_entry.end(), [](const Entry& a, const Entry& b) {
+    bool a_is_total = a.name == "Total";
+    bool b_is_total = b.name == "Total";
+    if (a_is_total != b_is_total) {
+      return a_is_total;
+    }
+    if (a.minutes != b.minutes) {
+      return a.minutes > b.minutes;
+    }
+    return a.name < b.name;
+  });
+  support::TablePrinter p;
+  p.Row() << "ID"
+          << "Name"
+          << "Time (min)"
+          << "Percentage";
+  p.Separator();
+  for (int i = 0, n = table_entry.size(); i < n; ++i) {
+    const Entry& e = table_entry[i];
+    if (i == 0) {
+      // The "Total" row carries no ID.
+      p.Row() << "" << e.name << e.minutes << e.percentage;
+    } else {
+      p.Row() << i << e.name << e.minutes << e.percentage;
+    }
+  }
+  return p.AsStr();
+}
+
+// Start with no recorded segments and no running "Total" timer.
+Profiler::Profiler() {
+  ObjectPtr<ProfilerNode> node = make_object<ProfilerNode>();
+  node->total_timer = nullptr;
+  node->stats_sec.clear();
+  data_ = std::move(node);
+}
+
+/*!
+ * \brief Start timing a scope named `name` in the current profiler.
+ * \param name Key under which the elapsed time is accumulated in `stats_sec`.
+ * \return A callback that, when invoked, adds the seconds elapsed since this call to the
+ *  profiler's entry for `name`; a null PackedFunc if no profiler is active on this thread.
+ */
+PackedFunc ProfilerTimedScope(String name) {
+  if (Optional<Profiler> opt_profiler = Profiler::Current()) {
+    return TypedPackedFunc<void()>([profiler = opt_profiler.value(),                 //
+                                    tik = std::chrono::high_resolution_clock::now(),  //
+                                    name = std::move(name)]() {
+      auto tok = std::chrono::high_resolution_clock::now();
+      // Keep fractional seconds: casting to std::chrono::seconds would truncate every
+      // sub-second scope to 0 and drop the fractional part of longer ones.
+      double duration = std::chrono::duration<double>(tok - tik).count();
+      profiler->stats_sec[name] += duration;
+    });
+  }
+  return nullptr;
+}
+
+ScopedTimer Profiler::TimedScope(String name) {
+  // ProfilerTimedScope yields a null callback when no profiler is active; the resulting
+  // timer is then a no-op.
+  return ScopedTimer(ProfilerTimedScope(std::move(name)));
+}
+
+/**************** Context Manager ****************/
+
+/*! \brief The calling thread's stack of entered profilers; the innermost one is at the back. */
+std::vector<Profiler>* ThreadLocalProfilers() {
+  static thread_local std::vector<Profiler> stack;
+  return &stack;
+}
+
+void Profiler::EnterWithScope() {
+  // Become current before starting "Total": ProfilerTimedScope records into whichever
+  // profiler Profiler::Current() returns.
+  std::vector<Profiler>* stack = ThreadLocalProfilers();
+  stack->push_back(*this);
+  (*this)->total_timer = ProfilerTimedScope("Total");
+}
+
+/*!
+ * \brief Leave this profiler's scope: remove it from the calling thread's stack, then record
+ *  its "Total" time by invoking and clearing the pending total timer.
+ */
+void Profiler::ExitWithScope() {
+  std::vector<Profiler>* profilers = ThreadLocalProfilers();
+  // Exit must mirror enter: popping an empty stack is undefined behavior, and popping a
+  // different profiler would silently corrupt the nesting.
+  ICHECK(!profilers->empty())
+      << "Profiler::ExitWithScope called without a matching EnterWithScope";
+  ICHECK(profilers->back().same_as(*this))
+      << "Profiler::ExitWithScope called on a profiler that is not the current one";
+  profilers->pop_back();
+  if ((*this)->total_timer != nullptr) {
+    (*this)->total_timer();
+    (*this)->total_timer = nullptr;
+  }
+}
+
+Optional<Profiler> Profiler::Current() {
+  const std::vector<Profiler>& stack = *ThreadLocalProfilers();
+  if (!stack.empty()) {
+    // The innermost (most recently entered) profiler sits at the back.
+    return stack.back();
+  }
+  return NullOpt;
+}
+
+// Register the node type and expose the profiler API through the FFI; the Python wrapper
+// (python/tvm/meta_schedule/profiler.py) calls these via `_ffi_api`.
+TVM_REGISTER_NODE_TYPE(ProfilerNode);
+TVM_REGISTER_GLOBAL("meta_schedule.Profiler").set_body_typed([]() -> Profiler {
+  return Profiler();
+});
+// Context-manager hooks backing Python's `with Profiler():`.
+TVM_REGISTER_GLOBAL("meta_schedule.ProfilerEnterWithScope")
+    .set_body_method(&Profiler::EnterWithScope);
+TVM_REGISTER_GLOBAL("meta_schedule.ProfilerExitWithScope")
+    .set_body_method(&Profiler::ExitWithScope);
+TVM_REGISTER_GLOBAL("meta_schedule.ProfilerCurrent").set_body_typed(Profiler::Current);
+TVM_REGISTER_GLOBAL("meta_schedule.ProfilerGet").set_body_method<Profiler>(&ProfilerNode::Get);
+TVM_REGISTER_GLOBAL("meta_schedule.ProfilerTable").set_body_method<Profiler>(&ProfilerNode::Table);
+// Returns the deferred timing callback, or a null function when no profiler is active.
+TVM_REGISTER_GLOBAL("meta_schedule.ProfilerTimedScope").set_body_typed(ProfilerTimedScope);
+
+} // namespace meta_schedule
+} // namespace tvm
diff --git a/src/meta_schedule/search_strategy/evolutionary_search.cc
b/src/meta_schedule/search_strategy/evolutionary_search.cc
index 6935ee610e..acde7f65a8 100644
--- a/src/meta_schedule/search_strategy/evolutionary_search.cc
+++ b/src/meta_schedule/search_strategy/evolutionary_search.cc
@@ -220,6 +220,7 @@ Array<MeasureCandidate> AssembleCandidates(const
std::vector<Schedule>& picks,
std::vector<double> PredictNormalizedScore(const std::vector<Schedule>&
candidates,
const TuneContext& context, const
CostModel& cost_model,
const Array<ArgInfo>& args_info) {
+ auto _ = Profiler::TimedScope("EvoSearch/Evolve/PredictNormalizedScore");
ICHECK(!candidates.empty()) << "Candidates given for score prediction can
not be empty list!";
std::vector<double> scores =
cost_model->Predict(context, AssembleCandidates(candidates, args_info));
@@ -437,6 +438,7 @@ class EvolutionarySearchNode : public SearchStrategyNode {
};
std::vector<Schedule> EvolutionarySearchNode::State::PickBestFromDatabase(int
num) {
+ auto _ = Profiler::TimedScope("EvoSearch/PickBestFromDatabase");
std::vector<tir::Trace> measured_traces;
measured_traces.reserve(num);
Array<TuningRecord> top_records = this->database_->GetTopK(this->token_,
num);
@@ -466,6 +468,7 @@ std::vector<Schedule>
EvolutionarySearchNode::State::PickBestFromDatabase(int nu
}
std::vector<Schedule> EvolutionarySearchNode::State::SampleInitPopulation(int
num) {
+ auto _ = Profiler::TimedScope("EvoSearch/SampleInitPopulation");
ThreadedTraceApply pp(self->context_->postprocs);
std::vector<Schedule> out_schs;
while (static_cast<int>(out_schs.size()) < self->init_min_unmeasured) {
@@ -529,43 +532,46 @@ std::vector<Schedule>
EvolutionarySearchNode::State::EvolveWithCostModel(
ConcurrentBitmask cbmask(self->population_size);
std::vector<Schedule> next_population(self->population_size,
Schedule{nullptr});
// The worker function
- auto f_find_candidate = [&cbmask, &population, &next_population, &pp,
this](int thread_id,
-
int trace_id) {
- // Prepare samplers
- PerThreadData& data = this->per_thread_data_.at(thread_id);
- TRandState* rand_state = &data.rand_state;
- const IRModule& mod = data.mod;
- std::function<int()>& trace_sampler = data.trace_sampler;
- std::function<Optional<Mutator>()>& mutator_sampler =
data.mutator_sampler;
- Schedule& result = next_population.at(trace_id);
- int sampled_trace_id = -1;
- // Loop until success
- for (int fail_count = 0; fail_count <= self->genetic_max_fail_count;
++fail_count) {
- sampled_trace_id = trace_sampler();
- tir::Trace trace = population.at(sampled_trace_id)->trace().value();
- if (Optional<Mutator> opt_mutator = mutator_sampler()) {
- // Decision: mutate
- Mutator mutator = opt_mutator.value();
- if (Optional<tir::Trace> new_trace = mutator->Apply(trace,
rand_state)) {
- if (Optional<Schedule> sch = pp.Apply(mod, new_trace.value(),
rand_state)) {
- // note that sch's trace is different from new_trace
- // because it contains post-processing information
- result = sch.value();
- break;
+ {
+ auto _ = Profiler::TimedScope("EvoSearch/Evolve/Mutation");
+ auto f_find_candidate = [&cbmask, &population, &next_population, &pp,
this](int thread_id,
+
int trace_id) {
+ // Prepare samplers
+ PerThreadData& data = this->per_thread_data_.at(thread_id);
+ TRandState* rand_state = &data.rand_state;
+ const IRModule& mod = data.mod;
+ std::function<int()>& trace_sampler = data.trace_sampler;
+ std::function<Optional<Mutator>()>& mutator_sampler =
data.mutator_sampler;
+ Schedule& result = next_population.at(trace_id);
+ int sampled_trace_id = -1;
+ // Loop until success
+ for (int fail_count = 0; fail_count <= self->genetic_max_fail_count;
++fail_count) {
+ sampled_trace_id = trace_sampler();
+ tir::Trace trace = population.at(sampled_trace_id)->trace().value();
+ if (Optional<Mutator> opt_mutator = mutator_sampler()) {
+ // Decision: mutate
+ Mutator mutator = opt_mutator.value();
+ if (Optional<tir::Trace> new_trace = mutator->Apply(trace,
rand_state)) {
+ if (Optional<Schedule> sch = pp.Apply(mod, new_trace.value(),
rand_state)) {
+ // note that sch's trace is different from new_trace
+ // because it contains post-processing information
+ result = sch.value();
+ break;
+ }
}
+ } else if (cbmask.QueryAndMark(sampled_trace_id)) {
+ // Decision: do not mutate
+ break;
}
- } else if (cbmask.QueryAndMark(sampled_trace_id)) {
- // Decision: do not mutate
- break;
}
- }
- // if retry count exceeds the limit, reuse an old sample
- if (!result.defined()) {
- result = population.at(sampled_trace_id);
- }
- };
- support::parallel_for_dynamic(0, self->population_size,
self->context_->num_threads,
- f_find_candidate);
+ // if retry count exceeds the limit, reuse an old sample
+ if (!result.defined()) {
+ result = population.at(sampled_trace_id);
+ }
+ };
+ support::parallel_for_dynamic(0, self->population_size,
self->context_->num_threads,
+ f_find_candidate);
+ }
population.swap(next_population);
TVM_PY_LOG(INFO, self->context_->logging_func) << "Evolve iter #" << iter
<< " done. Summary:\n"
<< pp.SummarizeFailures();
diff --git a/src/meta_schedule/tune_context.cc
b/src/meta_schedule/tune_context.cc
index 362db0a380..0c70dcf5c4 100644
--- a/src/meta_schedule/tune_context.cc
+++ b/src/meta_schedule/tune_context.cc
@@ -75,6 +75,7 @@ void TuneContextNode::_SetMeasureCandidates(const
Array<MeasureCandidate>& candi
}
void TuneContextNode::_SendToBuilder(const Builder& builder) {
+ auto _ = Profiler::TimedScope("SendToBuilder");
Array<MeasureCandidate> candidates = this->measure_candidates.value();
Target target = this->target.value();
Array<BuilderInput> inputs;
@@ -86,6 +87,7 @@ void TuneContextNode::_SendToBuilder(const Builder& builder) {
}
void TuneContextNode::_SendToRunner(const Runner& runner) {
+ auto _ = Profiler::TimedScope("SendToRunner");
Array<MeasureCandidate> candidates = this->measure_candidates.value();
Array<BuilderResult> builder_results = this->builder_results.value();
Target target = this->target.value();
@@ -133,9 +135,12 @@ Array<RunnerResult> TuneContextNode::_Join() {
Array<RunnerFuture> futures = this->runner_futures.value();
int n = futures.size();
Array<RunnerResult> results;
- results.reserve(n);
- for (RunnerFuture future : futures) {
- results.push_back(future->Result());
+ {
+ auto _ = Profiler::TimedScope("JoinRunnerFutures");
+ results.reserve(n);
+ for (RunnerFuture future : futures) {
+ results.push_back(future->Result());
+ }
}
this->search_strategy.value()->NotifyRunnerResults(this->measure_candidates.value(),
results);
ICHECK(this->measure_candidates.defined());
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index 40c301c617..c399696a82 100644
--- a/src/meta_schedule/utils.h
+++ b/src/meta_schedule/utils.h
@@ -28,6 +28,7 @@
#include <tvm/meta_schedule/database.h>
#include <tvm/meta_schedule/feature_extractor.h>
#include <tvm/meta_schedule/measure_callback.h>
+#include <tvm/meta_schedule/profiler.h>
#include <tvm/meta_schedule/runner.h>
#include <tvm/meta_schedule/schedule_rule.h>
#include <tvm/meta_schedule/search_strategy.h>
diff --git a/tests/python/unittest/test_meta_schedule_profiler.py b/tests/python/unittest/test_meta_schedule_profiler.py
new file mode 100644
index 0000000000..36a3d634ba
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_profiler.py
@@ -0,0 +1,46 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" Test Meta Schedule Profiler """
+import time
+
+from tvm import meta_schedule as ms
+
+
+def test_meta_schedule_profiler_context_manager():
+ with ms.Profiler() as profiler:
+ time.sleep(1)
+ with ms.Profiler.timeit("Level0"):
+ time.sleep(1)
+ with ms.Profiler.timeit("Level1"):
+ time.sleep(2)
+ # Note that the results are in seconds
+
+ result = profiler.get()
+ assert len(result) == 3
+ assert 3.9 <= result["Total"] <= 4.1
+ assert 2.9 <= result["Level0"] <= 3.1
+ assert 1.9 <= result["Level1"] <= 2.1
+
+
+def test_meta_schedule_no_context():
+ with ms.Profiler.timeit("Level0"):
+ assert ms.Profiler.current() is None
+
+
+if __name__ == "__main__":
+    # Allow running this test file directly, without pytest.
+    test_meta_schedule_profiler_context_manager()
+    test_meta_schedule_no_context()
diff --git a/tests/python/unittest/test_meta_schedule_search_strategy.py
b/tests/python/unittest/test_meta_schedule_search_strategy.py
index fd8c023b5e..1201e4100a 100644
--- a/tests/python/unittest/test_meta_schedule_search_strategy.py
+++ b/tests/python/unittest/test_meta_schedule_search_strategy.py
@@ -119,7 +119,7 @@ def test_meta_schedule_replay_func(
assert num_trials_each_iter == [7, 7, 6]
-def test_meta_schedule_evolutionary_search(): # pylint: disable =
invalid-name]
+def test_meta_schedule_evolutionary_search(): # pylint: disable = invalid-name
def _schedule_matmul_small(sch: Schedule):
block = sch.get_block("matmul")
_, j, k = sch.get_loops(block=block)
@@ -185,7 +185,7 @@ def test_meta_schedule_evolutionary_search(): # pylint:
disable = invalid-name]
assert num_trials_each_iter.count(0) < 5
-def test_meta_schedule_evolutionary_search_early_stop(): # pylint: disable =
invalid-name]
+def test_meta_schedule_evolutionary_search_early_stop(): # pylint: disable =
invalid-name
def _schedule_matmul_empty(sch: Schedule):
return sch