This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e61ad7ab82 [MetaSchedule] Add Profiler Support For Tuning Efficiency Optimization (#11486)
e61ad7ab82 is described below

commit e61ad7ab826a73347280468e2da47f215f76e05d
Author: Xiyou Zhou <[email protected]>
AuthorDate: Mon Jun 13 08:41:53 2022 -0700

    [MetaSchedule] Add Profiler Support For Tuning Efficiency Optimization (#11486)
    
    Co-authored-by: Junru Shao <[email protected]>
---
 include/tvm/meta_schedule/profiler.h               | 103 ++++++++++++++++
 python/tvm/meta_schedule/__init__.py               |   1 +
 python/tvm/meta_schedule/profiler.py               |  76 ++++++++++++
 .../testing/tune_relay_meta_schedule.py            |  29 +++--
 .../measure_callback/add_to_database.cc            |   1 +
 .../measure_callback/echo_statistics.cc            |   1 +
 .../measure_callback/measure_callback.cc           |   1 +
 .../measure_callback/remove_build_artifact.cc      |   1 +
 .../measure_callback/update_cost_model.cc          |   6 +-
 src/meta_schedule/profiler.cc                      | 134 +++++++++++++++++++++
 .../search_strategy/evolutionary_search.cc         |  74 ++++++------
 src/meta_schedule/tune_context.cc                  |  11 +-
 src/meta_schedule/utils.h                          |   1 +
 .../python/unittest/test_meta_schedule_profiler.py |  46 +++++++
 .../unittest/test_meta_schedule_search_strategy.py |   4 +-
 15 files changed, 434 insertions(+), 55 deletions(-)

diff --git a/include/tvm/meta_schedule/profiler.h 
b/include/tvm/meta_schedule/profiler.h
new file mode 100644
index 0000000000..0f6572cca9
--- /dev/null
+++ b/include/tvm/meta_schedule/profiler.h
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_META_SCHEDULE_PROFILER_H_
+#define TVM_META_SCHEDULE_PROFILER_H_
+
+#include <tvm/ir/module.h>
+#include <tvm/node/reflection.h>
+#include <tvm/runtime/container/array.h>
+#include <tvm/runtime/container/optional.h>
+#include <tvm/runtime/container/string.h>
+#include <tvm/runtime/object.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/target/target.h>
+
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace meta_schedule {
+
+/*!
+ * \brief A scope guard that runs a deferred callback once, when the guard is destroyed.
+ *
+ * Instances are created by `Profiler::TimedScope`. The callback records the elapsed time of the
+ * enclosing scope into the profiler that was active when the guard was created, or is null when
+ * no profiler was active. The guard is move-only: a copy would run the callback twice and
+ * double-count the scope.
+ */
+class ScopedTimer {
+ public:
+  ~ScopedTimer() {
+    if (deferred_ != nullptr) {
+      deferred_();
+    }
+  }
+
+  ScopedTimer(const ScopedTimer&) = delete;
+  ScopedTimer& operator=(const ScopedTimer&) = delete;
+  /*! \brief Transfer the pending callback; the moved-from guard no longer does anything. */
+  ScopedTimer(ScopedTimer&& other) noexcept : deferred_(std::move(other.deferred_)) {
+    other.deferred_ = runtime::TypedPackedFunc<void()>();
+  }
+  ScopedTimer& operator=(ScopedTimer&&) = delete;
+
+ private:
+  friend class Profiler;
+
+  explicit ScopedTimer(runtime::TypedPackedFunc<void()> deferred)
+      : deferred_(std::move(deferred)) {}
+  /*! \brief Callback invoked on destruction; null when there is nothing to record. */
+  runtime::TypedPackedFunc<void()> deferred_;
+};
+
+/*!
+ * \brief A generic profiler that accumulates the elapsed time of named scopes.
+ */
+class ProfilerNode : public runtime::Object {
+ public:
+  /*! \brief Accumulated elapsed time, in seconds, of each profiled scope, keyed by scope name */
+  std::unordered_map<std::string, double> stats_sec;
+  /*!
+   * \brief Pending timer callback for the "Total" scope. It is created when entering the
+   *  profiler's scope, invoked and reset to null when exiting it, and null otherwise.
+   */
+  runtime::PackedFunc total_timer;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    // `stats_sec` is not visited.
+    // `total_timer` is not visited.
+  }
+
+  static constexpr const char* _type_key = "meta_schedule.Profiler";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ProfilerNode, runtime::Object);
+
+ public:
+  /*! \brief Get the accumulated time of each scope, in seconds, as float64 immediates */
+  Map<String, FloatImm> Get() const;
+  /*!
+   * \brief Return a summary of the profiling results in table format, listing each scope's
+   *  time in minutes and its percentage of the "Total" time.
+   * \note The "Total" entry must already be recorded, i.e. the profiler's scope must have been
+   *  exited, before calling this method.
+   */
+  String Table() const;
+};
+
+/*!
+ * \brief Managed reference to ProfilerNode
+ * \sa ProfilerNode
+ */
+class Profiler : public runtime::ObjectRef {
+ public:
+  /*! \brief Construct a profiler with no recorded stats and no running "Total" timer */
+  Profiler();
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Profiler, runtime::ObjectRef, ProfilerNode);
+
+  /*!
+   * \brief Entering the scope of the context manager: pushes this profiler onto the calling
+   *  thread's profiler stack and starts timing the "Total" scope.
+   */
+  void EnterWithScope();
+  /*!
+   * \brief Exiting the scope of the context manager: pops the calling thread's profiler stack
+   *  and records the elapsed "Total" time.
+   */
+  void ExitWithScope();
+  /*! \brief Returns the innermost profiler entered on the calling thread, or NullOpt if none */
+  static Optional<Profiler> Current();
+  /*!
+   * \brief Profile the time usage in the given scope in the given name.
+   * \param name Name for the scope.
+   * \return A scope timer for time profiling. When no profiler is active on the calling thread,
+   *  the returned timer records nothing.
+   */
+  static ScopedTimer TimedScope(String name);
+};
+
+}  // namespace meta_schedule
+}  // namespace tvm
+
+#endif  // TVM_META_SCHEDULE_PROFILER_H_
diff --git a/python/tvm/meta_schedule/__init__.py 
b/python/tvm/meta_schedule/__init__.py
index 0028fbdf4f..26cf446b10 100644
--- a/python/tvm/meta_schedule/__init__.py
+++ b/python/tvm/meta_schedule/__init__.py
@@ -30,6 +30,7 @@ from . import (
     search_strategy,
     space_generator,
 )
+from .profiler import Profiler
 from .apply_history_best import ApplyHistoryBest
 from .extracted_task import ExtractedTask
 from .relay_integration import extract_task_from_relay
diff --git a/python/tvm/meta_schedule/profiler.py 
b/python/tvm/meta_schedule/profiler.py
new file mode 100644
index 0000000000..a83d0fa16e
--- /dev/null
+++ b/python/tvm/meta_schedule/profiler.py
@@ -0,0 +1,76 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""A context manager that profiles tuning time cost for different parts."""
+from __future__ import annotations
+
+import logging
+from contextlib import contextmanager
+from typing import Dict, Optional
+
+from tvm._ffi import register_object
+from tvm.runtime import Object
+
+from . import _ffi_api
+
+logger = logging.getLogger(__name__)  # pylint: disable=invalid-name  # module logger; not used by the class below
+
+
+@register_object("meta_schedule.Profiler")
+class Profiler(Object):
+    """Tuning time profiler."""
+
+    def __init__(self) -> None:
+        self.__init_handle_by_constructor__(
+            _ffi_api.Profiler,  # type: ignore # pylint: disable=no-member
+        )
+
+    def get(self) -> Dict[str, float]:
+        """Get the profiling results in minutes"""
+        return _ffi_api.ProfilerGet(self)  # type: ignore # pylint: 
disable=no-member
+
+    def table(self) -> str:
+        """Get the profiling results in a table format"""
+        return _ffi_api.ProfilerTable(self)  # type: ignore # pylint: 
disable=no-member
+
+    def __enter__(self) -> "Profiler":
+        """Entering the scope of the context manager"""
+        _ffi_api.ProfilerEnterWithScope(self)  # type: ignore # pylint: 
disable=no-member
+        return self
+
+    def __exit__(self, ptype, value, trace) -> None:
+        """Exiting the scope of the context manager"""
+        _ffi_api.ProfilerExitWithScope(self)  # type: ignore # pylint: 
disable=no-member
+
+    @staticmethod
+    def current() -> Optional["Profiler"]:
+        """Get the current profiler."""
+        return _ffi_api.ProfilerCurrent()  # type: ignore # pylint: 
disable=no-member
+
+    @staticmethod
+    def timeit(name: str):
+        """Timeit a block of code"""
+
+        @contextmanager
+        def _timeit():
+            try:
+                f = _ffi_api.ProfilerTimedScope(name)  # type: ignore # 
pylint: disable=no-member
+                yield
+            finally:
+                if f:
+                    f()
+
+        return _timeit()
diff --git a/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py 
b/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py
index bd858e0f2d..ee26b6303d 100644
--- a/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py
+++ b/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py
@@ -122,19 +122,22 @@ def main():
         alloc_repeat=alloc_repeat,
         max_workers=ARGS.rpc_workers,
     )
-    lib = ms.tune_relay(
-        mod=mod,
-        target=ARGS.target,
-        config=ms.TuneConfig(
-            strategy="evolutionary",
-            num_trials_per_iter=64,
-            max_trials_per_task=ARGS.num_trials,
-            max_trials_global=ARGS.num_trials,
-        ),
-        runner=runner,  # type: ignore
-        work_dir=ARGS.work_dir,
-        params=params,
-    )
+    with ms.Profiler() as profiler:
+        lib = ms.tune_relay(
+            mod=mod,
+            target=ARGS.target,
+            config=ms.TuneConfig(
+                strategy="evolutionary",
+                num_trials_per_iter=64,
+                max_trials_per_task=ARGS.num_trials,
+                max_trials_global=ARGS.num_trials,
+            ),
+            runner=runner,  # type: ignore
+            work_dir=ARGS.work_dir,
+            params=params,
+        )
+    print("Tuning Time:")
+    print(profiler.table())
     graph, rt_mod, params = lib.graph_json, lib.lib, lib.params
     for input_name, input_shape in input_info.items():
         if input_dtype.startswith("float"):
diff --git a/src/meta_schedule/measure_callback/add_to_database.cc 
b/src/meta_schedule/measure_callback/add_to_database.cc
index 27b4e55a7d..e86da3720f 100644
--- a/src/meta_schedule/measure_callback/add_to_database.cc
+++ b/src/meta_schedule/measure_callback/add_to_database.cc
@@ -30,6 +30,7 @@ class AddToDatabaseNode : public MeasureCallbackNode {
     if (!task_scheduler->database.defined()) {
       return;
     }
+    auto _ = Profiler::TimedScope("AddToDatabase");
     TuneContext task = task_scheduler->tasks[task_id];
     Database database = task_scheduler->database.value();
     Workload workload = database->CommitWorkload(task->mod.value());
diff --git a/src/meta_schedule/measure_callback/echo_statistics.cc 
b/src/meta_schedule/measure_callback/echo_statistics.cc
index e45f98b52e..5f3dce06f0 100644
--- a/src/meta_schedule/measure_callback/echo_statistics.cc
+++ b/src/meta_schedule/measure_callback/echo_statistics.cc
@@ -82,6 +82,7 @@ class EchoStatisticsNode : public MeasureCallbackNode {
     if (this->task_info.empty()) {
       SetupTaskInfo(task_scheduler->tasks);
     }
+    auto _ = Profiler::TimedScope("EchoStatistics");
     ICHECK_EQ(measure_candidates.size(), builder_results.size());
     ICHECK_EQ(measure_candidates.size(), runner_results.size());
     int n = measure_candidates.size();
diff --git a/src/meta_schedule/measure_callback/measure_callback.cc 
b/src/meta_schedule/measure_callback/measure_callback.cc
index c7851a6fad..e49f5216ec 100644
--- a/src/meta_schedule/measure_callback/measure_callback.cc
+++ b/src/meta_schedule/measure_callback/measure_callback.cc
@@ -27,6 +27,7 @@ void PyMeasureCallbackNode::Apply(const TaskScheduler& 
task_scheduler,
                                   const Array<BuilderResult>& builds,          
       //
                                   const Array<RunnerResult>& results) {
   ICHECK(f_apply != nullptr) << "PyMeasureCallback's Apply method not 
implemented!";
+  auto _ = Profiler::TimedScope(this->f_as_string());
   return f_apply(task_scheduler, task_id, measure_candidates, builds, results);
 }
 
diff --git a/src/meta_schedule/measure_callback/remove_build_artifact.cc 
b/src/meta_schedule/measure_callback/remove_build_artifact.cc
index 649636def1..67267dff91 100644
--- a/src/meta_schedule/measure_callback/remove_build_artifact.cc
+++ b/src/meta_schedule/measure_callback/remove_build_artifact.cc
@@ -28,6 +28,7 @@ class RemoveBuildArtifactNode : public MeasureCallbackNode {
              const Array<BuilderResult>& builder_results,
              const Array<RunnerResult>& runner_results) final {
     static const PackedFunc* f_rm = 
runtime::Registry::Get("meta_schedule.remove_build_dir");
+    auto _ = Profiler::TimedScope("RemoveBuildArtifact");
     for (const BuilderResult& build_result : builder_results) {
       if (Optional<String> path = build_result->artifact_path) {
         (*f_rm)(path.value());
diff --git a/src/meta_schedule/measure_callback/update_cost_model.cc 
b/src/meta_schedule/measure_callback/update_cost_model.cc
index 00f6f94eb7..5b6208581c 100644
--- a/src/meta_schedule/measure_callback/update_cost_model.cc
+++ b/src/meta_schedule/measure_callback/update_cost_model.cc
@@ -27,11 +27,11 @@ class UpdateCostModelNode : public MeasureCallbackNode {
              const Array<MeasureCandidate>& measure_candidates,
              const Array<BuilderResult>& builder_results,
              const Array<RunnerResult>& runner_results) final {
+    auto _ = Profiler::TimedScope("UpdateCostModel");
     TuneContext task = task_scheduler->tasks[task_id];
-    ICHECK(task_scheduler->cost_model.defined())  //
+    ICHECK(task_scheduler->cost_model.defined())
         << "Cost model must be defined for the task scheduler!";
-    ICHECK(task->measure_candidates.defined())  //
-        << "Task's measure candidates must be present!";
+    ICHECK(task->measure_candidates.defined()) << "Task's measure candidates 
must be present!";
     CostModel cost_model = task_scheduler->cost_model.value();
     ICHECK_EQ(measure_candidates.size(), builder_results.size());
     ICHECK_EQ(runner_results.size(), builder_results.size());
diff --git a/src/meta_schedule/profiler.cc b/src/meta_schedule/profiler.cc
new file mode 100644
index 0000000000..d3f72bb705
--- /dev/null
+++ b/src/meta_schedule/profiler.cc
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <algorithm>
+
+#include "./utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+/**************** Profiler ****************/
+
+Map<String, FloatImm> ProfilerNode::Get() const {
+  // Box every scope's accumulated seconds as a float64 immediate.
+  Map<String, FloatImm> seconds_by_scope;
+  for (const std::pair<const std::string, double>& scope : stats_sec) {
+    const FloatImm seconds(DataType::Float(64), scope.second);
+    seconds_by_scope.Set(scope.first, seconds);
+  }
+  return seconds_by_scope;
+}
+
+String ProfilerNode::Table() const {
+  CHECK(!stats_sec.empty()) << "ValueError: The stats are empty. Please run the profiler first.";
+  CHECK(stats_sec.count("Total"))
+      << "ValueError: The total time is not recorded. This method should be called only after "
+         "exiting the profiler's with scope.";
+  const double total = stats_sec.at("Total");
+  // Share of the total time, in percent. A zero total would otherwise produce NaN, and NaN keys
+  // violate the strict weak ordering that std::sort requires.
+  auto percentage = [total](double seconds) -> double {
+    return total > 0.0 ? seconds / total * 100.0 : 0.0;
+  };
+  // Every scope except "Total", ordered by time spent (descending) and then by name. The name
+  // tie-break keeps the output independent of the hash map's iteration order.
+  std::vector<std::pair<std::string, double>> scopes;
+  scopes.reserve(stats_sec.size());
+  for (const auto& kv : stats_sec) {
+    if (kv.first != "Total") {
+      scopes.emplace_back(kv.first, kv.second);
+    }
+  }
+  std::sort(scopes.begin(), scopes.end(),
+            [](const std::pair<std::string, double>& a, const std::pair<std::string, double>& b) {
+              return a.second != b.second ? a.second > b.second : a.first < b.first;
+            });
+  support::TablePrinter p;
+  p.Row() << "ID"
+          << "Name"
+          << "Time (min)"
+          << "Percentage";
+  p.Separator();
+  // "Total" is always the first, unnumbered row; the remaining scopes are numbered from 1.
+  p.Row() << "" << "Total" << total / 60.0 << percentage(total);
+  for (int i = 0, n = static_cast<int>(scopes.size()); i < n; ++i) {
+    const std::pair<std::string, double>& scope = scopes[i];
+    p.Row() << (i + 1) << scope.first << scope.second / 60.0 << percentage(scope.second);
+  }
+  return p.AsStr();
+}
+
+Profiler::Profiler() {
+  // A fresh profiler has no "Total" timer running and no recorded scopes.
+  ObjectPtr<ProfilerNode> node = make_object<ProfilerNode>();
+  node->total_timer = nullptr;
+  node->stats_sec.clear();
+  data_ = std::move(node);
+}
+
+/*!
+ * \brief Start timing a scope named `name` against the current profiler of this thread.
+ * \param name Name of the scope; its elapsed time is added to `stats_sec[name]`.
+ * \return A callback that stops the timer and records the elapsed time in seconds, or a null
+ *  PackedFunc when no profiler is active on the calling thread.
+ */
+PackedFunc ProfilerTimedScope(String name) {
+  if (Optional<Profiler> opt_profiler = Profiler::Current()) {
+    // steady_clock is monotonic, so wall-clock adjustments cannot produce negative durations.
+    return TypedPackedFunc<void()>([profiler = opt_profiler.value(),     //
+                                    tik = std::chrono::steady_clock::now(),  //
+                                    name = std::move(name)]() {
+      auto tok = std::chrono::steady_clock::now();
+      // Keep fractional seconds: duration_cast<std::chrono::seconds> would truncate every
+      // measurement to a whole number of seconds, recording sub-second scopes as 0.
+      double duration = std::chrono::duration<double>(tok - tik).count();
+      profiler->stats_sec[name] += duration;
+    });
+  }
+  return nullptr;
+}
+
+ScopedTimer Profiler::TimedScope(String name) {
+  // The guard fires the timer callback on destruction; a null callback makes it a no-op.
+  PackedFunc on_exit = ProfilerTimedScope(std::move(name));
+  return ScopedTimer(on_exit);
+}
+
+/**************** Context Manager ****************/
+
+std::vector<Profiler>* ThreadLocalProfilers() {
+  // One stack of active profilers per thread; the innermost profiler sits at the back.
+  thread_local std::vector<Profiler> active_profilers;
+  return &active_profilers;
+}
+
+void Profiler::EnterWithScope() {
+  // Push first: the "Total" timer below binds to the current profiler, which must be this one.
+  std::vector<Profiler>* stack = ThreadLocalProfilers();
+  stack->push_back(*this);
+  PackedFunc total = ProfilerTimedScope("Total");
+  (*this)->total_timer = total;
+}
+
+void Profiler::ExitWithScope() {
+  std::vector<Profiler>* profilers = ThreadLocalProfilers();
+  // Scopes must be exited in LIFO order, on the thread that entered them. Popping an empty stack
+  // is undefined behavior, and popping a different profiler would silently corrupt the stack.
+  CHECK(!profilers->empty() && profilers->back().same_as(*this))
+      << "ValueError: ExitWithScope is called on a profiler that is not the innermost active "
+         "profiler of the current thread.";
+  profilers->pop_back();
+  // Stop the "Total" timer started by EnterWithScope, recording its elapsed time.
+  if ((*this)->total_timer != nullptr) {
+    (*this)->total_timer();
+    (*this)->total_timer = nullptr;
+  }
+}
+
+Optional<Profiler> Profiler::Current() {
+  // The innermost profiler entered on this thread, if any.
+  const std::vector<Profiler>& profilers = *ThreadLocalProfilers();
+  if (!profilers.empty()) {
+    return profilers.back();
+  }
+  return NullOpt;
+}
+
+TVM_REGISTER_NODE_TYPE(ProfilerNode);
+// FFI entry points backing python/tvm/meta_schedule/profiler.py
+TVM_REGISTER_GLOBAL("meta_schedule.Profiler").set_body_typed([]() -> Profiler {
+  return Profiler();
+});
+TVM_REGISTER_GLOBAL("meta_schedule.ProfilerEnterWithScope")
+    .set_body_typed([](Profiler profiler) -> void { profiler.EnterWithScope(); });
+TVM_REGISTER_GLOBAL("meta_schedule.ProfilerExitWithScope")
+    .set_body_typed([](Profiler profiler) -> void { profiler.ExitWithScope(); });
+TVM_REGISTER_GLOBAL("meta_schedule.ProfilerCurrent")
+    .set_body_typed([]() -> Optional<Profiler> { return Profiler::Current(); });
+TVM_REGISTER_GLOBAL("meta_schedule.ProfilerGet")
+    .set_body_typed([](Profiler profiler) -> Map<String, FloatImm> { return profiler->Get(); });
+TVM_REGISTER_GLOBAL("meta_schedule.ProfilerTable")
+    .set_body_typed([](Profiler profiler) -> String { return profiler->Table(); });
+TVM_REGISTER_GLOBAL("meta_schedule.ProfilerTimedScope")
+    .set_body_typed([](String name) -> PackedFunc { return ProfilerTimedScope(name); });
+
+}  // namespace meta_schedule
+}  // namespace tvm
diff --git a/src/meta_schedule/search_strategy/evolutionary_search.cc 
b/src/meta_schedule/search_strategy/evolutionary_search.cc
index 6935ee610e..acde7f65a8 100644
--- a/src/meta_schedule/search_strategy/evolutionary_search.cc
+++ b/src/meta_schedule/search_strategy/evolutionary_search.cc
@@ -220,6 +220,7 @@ Array<MeasureCandidate> AssembleCandidates(const 
std::vector<Schedule>& picks,
 std::vector<double> PredictNormalizedScore(const std::vector<Schedule>& 
candidates,
                                            const TuneContext& context, const 
CostModel& cost_model,
                                            const Array<ArgInfo>& args_info) {
+  auto _ = Profiler::TimedScope("EvoSearch/Evolve/PredictNormalizedScore");
   ICHECK(!candidates.empty()) << "Candidates given for score prediction can 
not be empty list!";
   std::vector<double> scores =
       cost_model->Predict(context, AssembleCandidates(candidates, args_info));
@@ -437,6 +438,7 @@ class EvolutionarySearchNode : public SearchStrategyNode {
 };
 
 std::vector<Schedule> EvolutionarySearchNode::State::PickBestFromDatabase(int 
num) {
+  auto _ = Profiler::TimedScope("EvoSearch/PickBestFromDatabase");
   std::vector<tir::Trace> measured_traces;
   measured_traces.reserve(num);
   Array<TuningRecord> top_records = this->database_->GetTopK(this->token_, 
num);
@@ -466,6 +468,7 @@ std::vector<Schedule> 
EvolutionarySearchNode::State::PickBestFromDatabase(int nu
 }
 
 std::vector<Schedule> EvolutionarySearchNode::State::SampleInitPopulation(int 
num) {
+  auto _ = Profiler::TimedScope("EvoSearch/SampleInitPopulation");
   ThreadedTraceApply pp(self->context_->postprocs);
   std::vector<Schedule> out_schs;
   while (static_cast<int>(out_schs.size()) < self->init_min_unmeasured) {
@@ -529,43 +532,46 @@ std::vector<Schedule> 
EvolutionarySearchNode::State::EvolveWithCostModel(
     ConcurrentBitmask cbmask(self->population_size);
     std::vector<Schedule> next_population(self->population_size, 
Schedule{nullptr});
     // The worker function
-    auto f_find_candidate = [&cbmask, &population, &next_population, &pp, 
this](int thread_id,
-                                                                               
 int trace_id) {
-      // Prepare samplers
-      PerThreadData& data = this->per_thread_data_.at(thread_id);
-      TRandState* rand_state = &data.rand_state;
-      const IRModule& mod = data.mod;
-      std::function<int()>& trace_sampler = data.trace_sampler;
-      std::function<Optional<Mutator>()>& mutator_sampler = 
data.mutator_sampler;
-      Schedule& result = next_population.at(trace_id);
-      int sampled_trace_id = -1;
-      // Loop until success
-      for (int fail_count = 0; fail_count <= self->genetic_max_fail_count; 
++fail_count) {
-        sampled_trace_id = trace_sampler();
-        tir::Trace trace = population.at(sampled_trace_id)->trace().value();
-        if (Optional<Mutator> opt_mutator = mutator_sampler()) {
-          // Decision: mutate
-          Mutator mutator = opt_mutator.value();
-          if (Optional<tir::Trace> new_trace = mutator->Apply(trace, 
rand_state)) {
-            if (Optional<Schedule> sch = pp.Apply(mod, new_trace.value(), 
rand_state)) {
-              // note that sch's trace is different from new_trace
-              // because it contains post-processing information
-              result = sch.value();
-              break;
+    {
+      auto _ = Profiler::TimedScope("EvoSearch/Evolve/Mutation");
+      auto f_find_candidate = [&cbmask, &population, &next_population, &pp, 
this](int thread_id,
+                                                                               
   int trace_id) {
+        // Prepare samplers
+        PerThreadData& data = this->per_thread_data_.at(thread_id);
+        TRandState* rand_state = &data.rand_state;
+        const IRModule& mod = data.mod;
+        std::function<int()>& trace_sampler = data.trace_sampler;
+        std::function<Optional<Mutator>()>& mutator_sampler = 
data.mutator_sampler;
+        Schedule& result = next_population.at(trace_id);
+        int sampled_trace_id = -1;
+        // Loop until success
+        for (int fail_count = 0; fail_count <= self->genetic_max_fail_count; 
++fail_count) {
+          sampled_trace_id = trace_sampler();
+          tir::Trace trace = population.at(sampled_trace_id)->trace().value();
+          if (Optional<Mutator> opt_mutator = mutator_sampler()) {
+            // Decision: mutate
+            Mutator mutator = opt_mutator.value();
+            if (Optional<tir::Trace> new_trace = mutator->Apply(trace, 
rand_state)) {
+              if (Optional<Schedule> sch = pp.Apply(mod, new_trace.value(), 
rand_state)) {
+                // note that sch's trace is different from new_trace
+                // because it contains post-processing information
+                result = sch.value();
+                break;
+              }
             }
+          } else if (cbmask.QueryAndMark(sampled_trace_id)) {
+            // Decision: do not mutate
+            break;
           }
-        } else if (cbmask.QueryAndMark(sampled_trace_id)) {
-          // Decision: do not mutate
-          break;
         }
-      }
-      // if retry count exceeds the limit, reuse an old sample
-      if (!result.defined()) {
-        result = population.at(sampled_trace_id);
-      }
-    };
-    support::parallel_for_dynamic(0, self->population_size, 
self->context_->num_threads,
-                                  f_find_candidate);
+        // if retry count exceeds the limit, reuse an old sample
+        if (!result.defined()) {
+          result = population.at(sampled_trace_id);
+        }
+      };
+      support::parallel_for_dynamic(0, self->population_size, 
self->context_->num_threads,
+                                    f_find_candidate);
+    }
     population.swap(next_population);
     TVM_PY_LOG(INFO, self->context_->logging_func) << "Evolve iter #" << iter 
<< " done. Summary:\n"
                                                    << pp.SummarizeFailures();
diff --git a/src/meta_schedule/tune_context.cc 
b/src/meta_schedule/tune_context.cc
index 362db0a380..0c70dcf5c4 100644
--- a/src/meta_schedule/tune_context.cc
+++ b/src/meta_schedule/tune_context.cc
@@ -75,6 +75,7 @@ void TuneContextNode::_SetMeasureCandidates(const 
Array<MeasureCandidate>& candi
 }
 
 void TuneContextNode::_SendToBuilder(const Builder& builder) {
+  auto _ = Profiler::TimedScope("SendToBuilder");
   Array<MeasureCandidate> candidates = this->measure_candidates.value();
   Target target = this->target.value();
   Array<BuilderInput> inputs;
@@ -86,6 +87,7 @@ void TuneContextNode::_SendToBuilder(const Builder& builder) {
 }
 
 void TuneContextNode::_SendToRunner(const Runner& runner) {
+  auto _ = Profiler::TimedScope("SendToRunner");
   Array<MeasureCandidate> candidates = this->measure_candidates.value();
   Array<BuilderResult> builder_results = this->builder_results.value();
   Target target = this->target.value();
@@ -133,9 +135,12 @@ Array<RunnerResult> TuneContextNode::_Join() {
   Array<RunnerFuture> futures = this->runner_futures.value();
   int n = futures.size();
   Array<RunnerResult> results;
-  results.reserve(n);
-  for (RunnerFuture future : futures) {
-    results.push_back(future->Result());
+  {
+    auto _ = Profiler::TimedScope("JoinRunnerFutures");
+    results.reserve(n);
+    for (RunnerFuture future : futures) {
+      results.push_back(future->Result());
+    }
   }
   
this->search_strategy.value()->NotifyRunnerResults(this->measure_candidates.value(),
 results);
   ICHECK(this->measure_candidates.defined());
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index 40c301c617..c399696a82 100644
--- a/src/meta_schedule/utils.h
+++ b/src/meta_schedule/utils.h
@@ -28,6 +28,7 @@
 #include <tvm/meta_schedule/database.h>
 #include <tvm/meta_schedule/feature_extractor.h>
 #include <tvm/meta_schedule/measure_callback.h>
+#include <tvm/meta_schedule/profiler.h>
 #include <tvm/meta_schedule/runner.h>
 #include <tvm/meta_schedule/schedule_rule.h>
 #include <tvm/meta_schedule/search_strategy.h>
diff --git a/tests/python/unittest/test_meta_schedule_profiler.py 
b/tests/python/unittest/test_meta_schedule_profiler.py
new file mode 100644
index 0000000000..36a3d634ba
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_profiler.py
@@ -0,0 +1,46 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" Test Meta Schedule Profiler """
+import time
+
+from tvm import meta_schedule as ms
+
+
+def test_meta_schedule_profiler_context_manager():
+    with ms.Profiler() as profiler:
+        time.sleep(1)
+        with ms.Profiler.timeit("Level0"):
+            time.sleep(1)
+            with ms.Profiler.timeit("Level1"):
+                time.sleep(2)
+    # Note that the results are in seconds
+
+    result = profiler.get()
+    assert len(result) == 3
+    assert 3.9 <= result["Total"] <= 4.1
+    assert 2.9 <= result["Level0"] <= 3.1
+    assert 1.9 <= result["Level1"] <= 2.1
+
+
+def test_meta_schedule_no_context():
+    with ms.Profiler.timeit("Level0"):
+        assert ms.Profiler.current() is None
+
+
+if __name__ == "__main__":
+    test_meta_schedule_profiler_context_manager()
+    test_meta_schedule_no_context()
diff --git a/tests/python/unittest/test_meta_schedule_search_strategy.py 
b/tests/python/unittest/test_meta_schedule_search_strategy.py
index fd8c023b5e..1201e4100a 100644
--- a/tests/python/unittest/test_meta_schedule_search_strategy.py
+++ b/tests/python/unittest/test_meta_schedule_search_strategy.py
@@ -119,7 +119,7 @@ def test_meta_schedule_replay_func(
     assert num_trials_each_iter == [7, 7, 6]
 
 
-def test_meta_schedule_evolutionary_search():  # pylint: disable = 
invalid-name]
+def test_meta_schedule_evolutionary_search():  # pylint: disable = invalid-name
     def _schedule_matmul_small(sch: Schedule):
         block = sch.get_block("matmul")
         _, j, k = sch.get_loops(block=block)
@@ -185,7 +185,7 @@ def test_meta_schedule_evolutionary_search():  # pylint: 
disable = invalid-name]
     assert num_trials_each_iter.count(0) < 5
 
 
-def test_meta_schedule_evolutionary_search_early_stop():  # pylint: disable = 
invalid-name]
+def test_meta_schedule_evolutionary_search_early_stop():  # pylint: disable = 
invalid-name
     def _schedule_matmul_empty(sch: Schedule):
         return sch
 

Reply via email to