This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 79cfb79  [M3c][MetaScheduler] Add ReplayFunc Search Strategy. (#9799)
79cfb79 is described below

commit 79cfb797ef084f6bf170e54bddc6e4439f00ca5f
Author: Xiyou Zhou <[email protected]>
AuthorDate: Wed Jan 5 14:34:26 2022 -0800

    [M3c][MetaScheduler] Add ReplayFunc Search Strategy. (#9799)
    
    * Modify TuneContext, TaskScheduler & SearchStrategy functions.
    
    Co-authored-by: Junru Shao <[email protected]>
    Co-authored-by: Bohan Hou <[email protected]>
    Co-authored-by: Ruihang Lai <[email protected]>
    Co-authored-by: Hongyi Jin <[email protected]>
    Co-authored-by: Wuwei Lin <[email protected]>
    Co-authored-by: Siyuan Feng <[email protected]>
    
    * Retrigger CI.
    
    * Add ReplayFunc and EvolutionarySearch strategy.
    
    Co-authored-by: Junru Shao <[email protected]>
    Co-authored-by: Bohan Hou <[email protected]>
    Co-authored-by: Ruihang Lai <[email protected]>
    Co-authored-by: Hongyi Jin <[email protected]>
    Co-authored-by: Wuwei Lin <[email protected]>
    Co-authored-by: Siyuan Feng <[email protected]>
    
    * Fix optional task name.
    
    Co-authored-by: Junru Shao <[email protected]>
    Co-authored-by: Bohan Hou <[email protected]>
    Co-authored-by: Ruihang Lai <[email protected]>
    Co-authored-by: Hongyi Jin <[email protected]>
    Co-authored-by: Wuwei Lin <[email protected]>
    Co-authored-by: Siyuan Feng <[email protected]>
    
    * Remove extra files.
    
    * Fix things.
    
    Co-authored-by: Junru Shao <[email protected]>
    Co-authored-by: Bohan Hou <[email protected]>
    Co-authored-by: Ruihang Lai <[email protected]>
    Co-authored-by: Hongyi Jin <[email protected]>
    Co-authored-by: Wuwei Lin <[email protected]>
    Co-authored-by: Siyuan Feng <[email protected]>
---
 .../tvm/meta_schedule/search_strategy/__init__.py  |   6 +-
 .../meta_schedule/search_strategy/replay_func.py   |  63 +++++++++
 src/meta_schedule/mutator/mutator.cc               |  57 ++++++++
 src/meta_schedule/postproc/postproc.cc             |  53 ++++++++
 src/meta_schedule/search_strategy/replay_func.cc   | 151 +++++++++++++++++++++
 .../unittest/test_meta_schedule_search_strategy.py |  15 +-
 6 files changed, 335 insertions(+), 10 deletions(-)

diff --git a/python/tvm/meta_schedule/search_strategy/__init__.py b/python/tvm/meta_schedule/search_strategy/__init__.py
index 298cdae..f385b72 100644
--- a/python/tvm/meta_schedule/search_strategy/__init__.py
+++ b/python/tvm/meta_schedule/search_strategy/__init__.py
@@ -19,5 +19,7 @@ The tvm.meta_schedule.search_strategy package.
 Meta Schedule search strategy utilizes the design spaces given
 to generate measure candidates.
 """
-from .search_strategy import MeasureCandidate, PySearchStrategy, SearchStrategy
-from .replay_trace import ReplayTrace
+
+from .search_strategy import SearchStrategy, PySearchStrategy, MeasureCandidate
+from .replay_trace import ReplayTrace, ReplayTraceConfig
+from .replay_func import ReplayFunc, ReplayFuncConfig
diff --git a/python/tvm/meta_schedule/search_strategy/replay_func.py b/python/tvm/meta_schedule/search_strategy/replay_func.py
new file mode 100644
index 0000000..eacc277
--- /dev/null
+++ b/python/tvm/meta_schedule/search_strategy/replay_func.py
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Replay Trace Search Strategy"""
+from typing import NamedTuple
+
+from tvm._ffi import register_object
+
+from .. import _ffi_api
+from .search_strategy import SearchStrategy
+
+
+@register_object("meta_schedule.ReplayFunc")
+class ReplayFunc(SearchStrategy):
+    """
+    Replay Func Search Strategy is a search strategy that generates measure candidates by
+    calling a design space generator and transform the design space.
+
+    Parameters
+    ----------
+    num_trials_per_iter : int
+        Number of trials per iteration.
+    num_trials_total : int
+        Total number of trials.
+    """
+
+    num_trials_per_iter: int
+    num_trials_total: int
+
+    def __init__(
+        self,
+        num_trials_per_iter: int,
+        num_trials_total: int,
+    ):
+        """Constructor"""
+        self.__init_handle_by_constructor__(
+            _ffi_api.SearchStrategyReplayFunc,  # type: ignore # pylint: disable=no-member
+            num_trials_per_iter,
+            num_trials_total,
+        )
+
+
+class ReplayFuncConfig(NamedTuple):
+    """Configuration for ReplayFunc"""
+
+    num_trials_per_iter: int
+    num_trials_total: int
+
+    def create_strategy(self) -> ReplayFunc:
+        return ReplayFunc(self.num_trials_per_iter, self.num_trials_total)
diff --git a/src/meta_schedule/mutator/mutator.cc b/src/meta_schedule/mutator/mutator.cc
new file mode 100644
index 0000000..27383ad
--- /dev/null
+++ b/src/meta_schedule/mutator/mutator.cc
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+Mutator Mutator::PyMutator(
+    PyMutatorNode::FInitializeWithTuneContext f_initialize_with_tune_context,  //
+    PyMutatorNode::FApply f_apply,                                             //
+    PyMutatorNode::FAsString f_as_string) {
+  ObjectPtr<PyMutatorNode> n = make_object<PyMutatorNode>();
+  n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context);
+  n->f_apply = std::move(f_apply);
+  n->f_as_string = std::move(f_as_string);
+  return Mutator(n);
+}
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<PyMutatorNode>([](const ObjectRef& n, ReprPrinter* p) {
+      const auto* self = n.as<PyMutatorNode>();
+      ICHECK(self);
+      PyMutatorNode::FAsString f_as_string = (*self).f_as_string;
+      ICHECK(f_as_string != nullptr) << "PyMutator's AsString method not implemented!";
+      p->stream << f_as_string();
+    });
+
+TVM_REGISTER_OBJECT_TYPE(MutatorNode);
+TVM_REGISTER_NODE_TYPE(PyMutatorNode);
+
+TVM_REGISTER_GLOBAL("meta_schedule.MutatorInitializeWithTuneContext")
+    .set_body_method<Mutator>(&MutatorNode::InitializeWithTuneContext);
+TVM_REGISTER_GLOBAL("meta_schedule.MutatorApply")
+    .set_body_typed([](Mutator self, tir::Trace trace, TRandState seed) -> Optional<tir::Trace> {
+      TRandState seed_ = (seed != -1) ? seed : support::LinearCongruentialEngine::DeviceRandom();
+      return self->Apply(trace, &seed_);
+    });
+TVM_REGISTER_GLOBAL("meta_schedule.MutatorPyMutator").set_body_typed(Mutator::PyMutator);
+
+}  // namespace meta_schedule
+}  // namespace tvm
diff --git a/src/meta_schedule/postproc/postproc.cc b/src/meta_schedule/postproc/postproc.cc
new file mode 100644
index 0000000..ff069e2
--- /dev/null
+++ b/src/meta_schedule/postproc/postproc.cc
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+Postproc Postproc::PyPostproc(
+    PyPostprocNode::FInitializeWithTuneContext f_initialize_with_tune_context,  //
+    PyPostprocNode::FApply f_apply,                                             //
+    PyPostprocNode::FAsString f_as_string) {
+  ObjectPtr<PyPostprocNode> n = make_object<PyPostprocNode>();
+  n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context);
+  n->f_apply = std::move(f_apply);
+  n->f_as_string = std::move(f_as_string);
+  return Postproc(n);
+}
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<PyPostprocNode>([](const ObjectRef& n, ReprPrinter* p) {
+      const auto* self = n.as<PyPostprocNode>();
+      ICHECK(self);
+      PyPostprocNode::FAsString f_as_string = (*self).f_as_string;
+      ICHECK(f_as_string != nullptr) << "PyPostproc's AsString method not implemented!";
+      p->stream << f_as_string();
+    });
+
+TVM_REGISTER_OBJECT_TYPE(PostprocNode);
+TVM_REGISTER_NODE_TYPE(PyPostprocNode);
+
+TVM_REGISTER_GLOBAL("meta_schedule.PostprocInitializeWithTuneContext")
+    .set_body_method<Postproc>(&PostprocNode::InitializeWithTuneContext);
+TVM_REGISTER_GLOBAL("meta_schedule.PostprocApply").set_body_method<Postproc>(&PostprocNode::Apply);
+TVM_REGISTER_GLOBAL("meta_schedule.PostprocPyPostproc").set_body_typed(Postproc::PyPostproc);
+
+}  // namespace meta_schedule
+}  // namespace tvm
diff --git a/src/meta_schedule/search_strategy/replay_func.cc b/src/meta_schedule/search_strategy/replay_func.cc
new file mode 100644
index 0000000..7592a8a
--- /dev/null
+++ b/src/meta_schedule/search_strategy/replay_func.cc
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+/*! \brief A search strategy that generates measure candidates using space generator. */
+class ReplayFuncNode : public SearchStrategyNode {
+ public:
+  /*! \brief The state of the search strategy. */
+  struct State {
+    /*! \brief The search strategy itself */
+    ReplayFuncNode* self;
+    /*! \brief `[st, ed)` are the indices of the next batch of candidates. */
+    int st;
+    /*! \brief `[st, ed)` are the indices of the next batch of candidates. */
+    int ed;
+
+    explicit State(ReplayFuncNode* self) : self(self), st(0), ed(self->num_trials_per_iter) {}
+
+    inline Optional<Array<MeasureCandidate>> GenerateMeasureCandidates();
+    inline void NotifyRunnerResults(const Array<RunnerResult>& results);
+  };
+
+  /*! \brief The number of trials per iteration. */
+  int num_trials_per_iter;
+  /*! \brief The number of total trials. */
+  int num_trials_total;
+
+  /*! \brief The module to be tuned. */
+  IRModule mod_{nullptr};
+  /*! \brief The metadata of the function arguments. */
+  Array<ArgInfo> args_info_{nullptr};
+  /*! \brief The post processors */
+  Array<Postproc> postprocs_{nullptr};
+  /*! \brief The space generator for measure candidates generation. */
+  SpaceGenerator space_generator_{nullptr};
+  /*! \brief The random state. -1 means using random number. */
+  TRandState rand_state_ = -1;
+  /*! \brief The state of the search strategy. */
+  std::unique_ptr<State> state_ = nullptr;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("num_trials_per_iter", &num_trials_per_iter);
+    v->Visit("num_trials_total", &num_trials_total);
+    // `space_generator_` is not visited
+    // `mod_` is not visited
+    // `args_info_` is not visited
+    // `num_threads_` is not visited
+    // `rand_state_` is not visited
+    // `state_` is not visited
+  }
+
+  static constexpr const char* _type_key = "meta_schedule.ReplayFunc";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ReplayFuncNode, SearchStrategyNode);
+
+  void InitializeWithTuneContext(const TuneContext& context) final {
+    this->space_generator_ = context->space_generator.value();
+    this->mod_ = context->mod.value();
+    this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(context->mod.value()));
+    this->postprocs_ = context->postprocs;
+    this->rand_state_ = ForkSeed(&context->rand_state);
+    this->state_.reset();
+  }
+
+  void PreTuning(const Array<tir::Schedule>& design_spaces) final {
+    ICHECK(this->state_ == nullptr);
+    this->state_ = std::make_unique<State>(this);
+  }
+
+  void PostTuning() final {
+    ICHECK(this->state_ != nullptr);
+    this->state_.reset();
+  }
+
+  Optional<Array<MeasureCandidate>> GenerateMeasureCandidates() final {
+    ICHECK(this->state_ != nullptr);
+    return this->state_->GenerateMeasureCandidates();
+  }
+
+  void NotifyRunnerResults(const TuneContext& context,
+                           const Array<MeasureCandidate>& measure_candidates,
+                           const Array<RunnerResult>& results) final {
+    ICHECK(this->state_ != nullptr);
+    this->state_->NotifyRunnerResults(results);
+  }
+};
+
+inline Optional<Array<MeasureCandidate>> ReplayFuncNode::State::GenerateMeasureCandidates() {
+  if (st >= self->num_trials_total) {
+    return NullOpt;
+  }
+  ed = std::min(ed, self->num_trials_total);
+  Array<MeasureCandidate> result;
+  for (int i = st; i < ed; i++) {
+    for (;;) {
+      Array<tir::Schedule> schs = self->space_generator_->GenerateDesignSpace(self->mod_);
+      int design_space_index = tir::SampleInt(&self->rand_state_, 0, schs.size());
+      tir::Schedule sch = schs[design_space_index];
+      sch->EnterPostproc();
+      bool failed = false;
+      for (const Postproc& proc : self->postprocs_) {
+        if (!proc->Apply(sch)) {
+          failed = true;
+          break;
+        }
+      }
+      if (!failed) {
+        result.push_back(MeasureCandidate(sch, self->args_info_));
+        break;
+      }
+    }
+  }
+  return result;
+}
+
+inline void ReplayFuncNode::State::NotifyRunnerResults(const Array<RunnerResult>& results) {
+  st += self->num_trials_per_iter;
+  ed += self->num_trials_per_iter;
+}
+
+SearchStrategy SearchStrategy::ReplayFunc(int num_trials_per_iter, int num_trials_total) {
+  ObjectPtr<ReplayFuncNode> n = make_object<ReplayFuncNode>();
+  n->num_trials_per_iter = num_trials_per_iter;
+  n->num_trials_total = num_trials_total;
+  return SearchStrategy(n);
+}
+
+TVM_REGISTER_NODE_TYPE(ReplayFuncNode);
+TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyReplayFunc")
+    .set_body_typed(SearchStrategy::ReplayFunc);
+
+}  // namespace meta_schedule
+}  // namespace tvm
diff --git a/tests/python/unittest/test_meta_schedule_search_strategy.py b/tests/python/unittest/test_meta_schedule_search_strategy.py
index 668fca9..a4d3217 100644
--- a/tests/python/unittest/test_meta_schedule_search_strategy.py
+++ b/tests/python/unittest/test_meta_schedule_search_strategy.py
@@ -18,12 +18,11 @@
 # pylint: disable=missing-function-docstring
 import sys
 import pytest
-from typing import List
-
 import tvm
 from tvm.meta_schedule import TuneContext
 from tvm.meta_schedule.runner import RunnerResult
 from tvm.meta_schedule.search_strategy import (
+    ReplayFunc,
     ReplayTrace,
     SearchStrategy,
 )
@@ -75,17 +74,17 @@ def _schedule_matmul(sch: Schedule):
     sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3)
 
 
[email protected]("TestClass", [ReplayTrace])
[email protected]("TestClass", [ReplayFunc, ReplayTrace])
 def test_meta_schedule_replay_func(TestClass: SearchStrategy):  # pylint: disable = invalid-name
     num_trials_per_iter = 7
     num_trials_total = 20
 
     strategy = TestClass(num_trials_per_iter=num_trials_per_iter, num_trials_total=num_trials_total)
-    context = TuneContext(mod=Matmul, space_generator=ScheduleFn(sch_fn=_schedule_matmul))
-    context.space_generator.initialize_with_tune_context(context)
-    spaces = context.space_generator.generate_design_space(context.mod)
+    tune_context = TuneContext(mod=Matmul, space_generator=ScheduleFn(sch_fn=_schedule_matmul))
+    tune_context.space_generator.initialize_with_tune_context(tune_context)
+    spaces = tune_context.space_generator.generate_design_space(tune_context.mod)
 
-    strategy.initialize_with_tune_context(context)
+    strategy.initialize_with_tune_context(tune_context)
     strategy.pre_tuning(spaces)
     (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul)
     num_trials_each_iter: List[int] = []
@@ -100,7 +99,7 @@ def test_meta_schedule_replay_func(TestClass: SearchStrategy):  # pylint: disabl
                 remove_decisions=(isinstance(strategy, ReplayTrace)),
             )
             runner_results.append(RunnerResult(run_secs=[0.11, 0.41, 0.54], error_msg=None))
-        strategy.notify_runner_results(context, candidates, runner_results)
+        strategy.notify_runner_results(tune_context, candidates, runner_results)
         candidates = strategy.generate_measure_candidates()
     strategy.post_tuning()
     assert num_trials_each_iter == [7, 7, 6]

Reply via email to