This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 79cfb79 [M3c][MetaScheduler] Add ReplayFunc Search Strategy. (#9799)
79cfb79 is described below
commit 79cfb797ef084f6bf170e54bddc6e4439f00ca5f
Author: Xiyou Zhou <[email protected]>
AuthorDate: Wed Jan 5 14:34:26 2022 -0800
[M3c][MetaScheduler] Add ReplayFunc Search Strategy. (#9799)
* Modify TuneContext, TaskScheduler & SearchStrategy functions.
Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Wuwei Lin <[email protected]>
Co-authored-by: Siyuan Feng <[email protected]>
* Retrigger CI.
* Add ReplayFunc and EvolutionarySearch strategy.
Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Wuwei Lin <[email protected]>
Co-authored-by: Siyuan Feng <[email protected]>
* Fix optional task name.
Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Wuwei Lin <[email protected]>
Co-authored-by: Siyuan Feng <[email protected]>
* Remove extra files.
* Fix things.
Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Wuwei Lin <[email protected]>
Co-authored-by: Siyuan Feng <[email protected]>
---
.../tvm/meta_schedule/search_strategy/__init__.py | 6 +-
.../meta_schedule/search_strategy/replay_func.py | 63 +++++++++
src/meta_schedule/mutator/mutator.cc | 57 ++++++++
src/meta_schedule/postproc/postproc.cc | 53 ++++++++
src/meta_schedule/search_strategy/replay_func.cc | 151 +++++++++++++++++++++
.../unittest/test_meta_schedule_search_strategy.py | 15 +-
6 files changed, 335 insertions(+), 10 deletions(-)
diff --git a/python/tvm/meta_schedule/search_strategy/__init__.py b/python/tvm/meta_schedule/search_strategy/__init__.py
index 298cdae..f385b72 100644
--- a/python/tvm/meta_schedule/search_strategy/__init__.py
+++ b/python/tvm/meta_schedule/search_strategy/__init__.py
@@ -19,5 +19,7 @@ The tvm.meta_schedule.search_strategy package.
Meta Schedule search strategy utilizes the design spaces given
to generate measure candidates.
"""
-from .search_strategy import MeasureCandidate, PySearchStrategy, SearchStrategy
-from .replay_trace import ReplayTrace
+
+from .search_strategy import SearchStrategy, PySearchStrategy, MeasureCandidate
+from .replay_trace import ReplayTrace, ReplayTraceConfig
+from .replay_func import ReplayFunc, ReplayFuncConfig
diff --git a/python/tvm/meta_schedule/search_strategy/replay_func.py b/python/tvm/meta_schedule/search_strategy/replay_func.py
new file mode 100644
index 0000000..eacc277
--- /dev/null
+++ b/python/tvm/meta_schedule/search_strategy/replay_func.py
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Replay Func Search Strategy"""
+from typing import NamedTuple
+
+from tvm._ffi import register_object
+
+from .. import _ffi_api
+from .search_strategy import SearchStrategy
+
+
+@register_object("meta_schedule.ReplayFunc")
+class ReplayFunc(SearchStrategy):
+ """
+ Replay Func Search Strategy is a search strategy that generates measure
candidates by
+ calling a design space generator and transform the design space.
+
+ Parameters
+ ----------
+ num_trials_per_iter : int
+ Number of trials per iteration.
+ num_trials_total : int
+ Total number of trials.
+ """
+
+ num_trials_per_iter: int
+ num_trials_total: int
+
+ def __init__(
+ self,
+ num_trials_per_iter: int,
+ num_trials_total: int,
+ ):
+ """Constructor"""
+ self.__init_handle_by_constructor__(
+ _ffi_api.SearchStrategyReplayFunc, # type: ignore # pylint:
disable=no-member
+ num_trials_per_iter,
+ num_trials_total,
+ )
+
+
+class ReplayFuncConfig(NamedTuple):
+ """Configuration for ReplayFunc"""
+
+ num_trials_per_iter: int
+ num_trials_total: int
+
+ def create_strategy(self) -> ReplayFunc:
+ return ReplayFunc(self.num_trials_per_iter, self.num_trials_total)
diff --git a/src/meta_schedule/mutator/mutator.cc b/src/meta_schedule/mutator/mutator.cc
new file mode 100644
index 0000000..27383ad
--- /dev/null
+++ b/src/meta_schedule/mutator/mutator.cc
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+Mutator Mutator::PyMutator(
+    PyMutatorNode::FInitializeWithTuneContext f_initialize_with_tune_context,  //
+    PyMutatorNode::FApply f_apply,                                             //
+    PyMutatorNode::FAsString f_as_string) {
+  ObjectPtr<PyMutatorNode> n = make_object<PyMutatorNode>();
+  n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context);
+  n->f_apply = std::move(f_apply);
+  n->f_as_string = std::move(f_as_string);
+  return Mutator(n);
+}
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<PyMutatorNode>([](const ObjectRef& n, ReprPrinter* p) {
+      const auto* self = n.as<PyMutatorNode>();
+      ICHECK(self);
+      PyMutatorNode::FAsString f_as_string = (*self).f_as_string;
+      ICHECK(f_as_string != nullptr) << "PyMutator's AsString method not implemented!";
+      p->stream << f_as_string();
+    });
+
+TVM_REGISTER_OBJECT_TYPE(MutatorNode);
+TVM_REGISTER_NODE_TYPE(PyMutatorNode);
+
+TVM_REGISTER_GLOBAL("meta_schedule.MutatorInitializeWithTuneContext")
+    .set_body_method<Mutator>(&MutatorNode::InitializeWithTuneContext);
+TVM_REGISTER_GLOBAL("meta_schedule.MutatorApply")
+    .set_body_typed([](Mutator self, tir::Trace trace, TRandState seed) -> Optional<tir::Trace> {
+      TRandState seed_ = (seed != -1) ? seed : support::LinearCongruentialEngine::DeviceRandom();
+      return self->Apply(trace, &seed_);  // seed == -1 requests a device-random state
+    });
+TVM_REGISTER_GLOBAL("meta_schedule.MutatorPyMutator").set_body_typed(Mutator::PyMutator);
+
+}  // namespace meta_schedule
+}  // namespace tvm
diff --git a/src/meta_schedule/postproc/postproc.cc b/src/meta_schedule/postproc/postproc.cc
new file mode 100644
index 0000000..ff069e2
--- /dev/null
+++ b/src/meta_schedule/postproc/postproc.cc
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+Postproc Postproc::PyPostproc(
+    PyPostprocNode::FInitializeWithTuneContext f_initialize_with_tune_context,  //
+    PyPostprocNode::FApply f_apply,                                             //
+    PyPostprocNode::FAsString f_as_string) {
+  ObjectPtr<PyPostprocNode> n = make_object<PyPostprocNode>();
+  n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context);
+  n->f_apply = std::move(f_apply);
+  n->f_as_string = std::move(f_as_string);
+  return Postproc(n);
+}
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<PyPostprocNode>([](const ObjectRef& n, ReprPrinter* p) {
+      const auto* self = n.as<PyPostprocNode>();
+      ICHECK(self);
+      PyPostprocNode::FAsString f_as_string = (*self).f_as_string;
+      ICHECK(f_as_string != nullptr) << "PyPostproc's AsString method not implemented!";
+      p->stream << f_as_string();
+    });
+
+TVM_REGISTER_OBJECT_TYPE(PostprocNode);
+TVM_REGISTER_NODE_TYPE(PyPostprocNode);
+
+TVM_REGISTER_GLOBAL("meta_schedule.PostprocInitializeWithTuneContext")
+    .set_body_method<Postproc>(&PostprocNode::InitializeWithTuneContext);
+TVM_REGISTER_GLOBAL("meta_schedule.PostprocApply").set_body_method<Postproc>(&PostprocNode::Apply);
+TVM_REGISTER_GLOBAL("meta_schedule.PostprocPyPostproc").set_body_typed(Postproc::PyPostproc);
+
+}  // namespace meta_schedule
+}  // namespace tvm
diff --git a/src/meta_schedule/search_strategy/replay_func.cc b/src/meta_schedule/search_strategy/replay_func.cc
new file mode 100644
index 0000000..7592a8a
--- /dev/null
+++ b/src/meta_schedule/search_strategy/replay_func.cc
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+/*! \brief A search strategy that generates measure candidates using space generator. */
+class ReplayFuncNode : public SearchStrategyNode {
+ public:
+  /*! \brief The state of the search strategy. */
+  struct State {
+    /*! \brief The search strategy itself */
+    ReplayFuncNode* self;
+    /*! \brief `[st, ed)` are the indices of the next batch of candidates. */
+    int st;
+    /*! \brief `[st, ed)` are the indices of the next batch of candidates. */
+    int ed;
+
+    explicit State(ReplayFuncNode* self) : self(self), st(0), ed(self->num_trials_per_iter) {}
+
+    inline Optional<Array<MeasureCandidate>> GenerateMeasureCandidates();
+    inline void NotifyRunnerResults(const Array<RunnerResult>& results);
+  };
+
+  /*! \brief The number of trials per iteration. */
+  int num_trials_per_iter;
+  /*! \brief The number of total trials. */
+  int num_trials_total;
+
+  /*! \brief The module to be tuned. */
+  IRModule mod_{nullptr};
+  /*! \brief The metadata of the function arguments. */
+  Array<ArgInfo> args_info_{nullptr};
+  /*! \brief The postprocessors applied to every sampled schedule. */
+  Array<Postproc> postprocs_{nullptr};
+  /*! \brief The space generator for measure candidates generation. */
+  SpaceGenerator space_generator_{nullptr};
+  /*! \brief The random state. -1 means using random number. */
+  TRandState rand_state_ = -1;
+  /*! \brief The state of the search strategy. */
+  std::unique_ptr<State> state_ = nullptr;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("num_trials_per_iter", &num_trials_per_iter);
+    v->Visit("num_trials_total", &num_trials_total);
+    // `space_generator_` is not visited
+    // `mod_` is not visited
+    // `args_info_` is not visited
+    // `postprocs_` is not visited
+    // `rand_state_` is not visited
+    // `state_` is not visited
+  }
+
+  static constexpr const char* _type_key = "meta_schedule.ReplayFunc";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ReplayFuncNode, SearchStrategyNode);
+
+  void InitializeWithTuneContext(const TuneContext& context) final {
+    this->space_generator_ = context->space_generator.value();
+    this->mod_ = context->mod.value();
+    this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(context->mod.value()));
+    this->postprocs_ = context->postprocs;
+    this->rand_state_ = ForkSeed(&context->rand_state);
+    this->state_.reset();
+  }
+
+  void PreTuning(const Array<tir::Schedule>& design_spaces) final {
+    ICHECK(this->state_ == nullptr);
+    this->state_ = std::make_unique<State>(this);
+  }
+
+  void PostTuning() final {
+    ICHECK(this->state_ != nullptr);
+    this->state_.reset();
+  }
+
+  Optional<Array<MeasureCandidate>> GenerateMeasureCandidates() final {
+    ICHECK(this->state_ != nullptr);
+    return this->state_->GenerateMeasureCandidates();
+  }
+
+  void NotifyRunnerResults(const TuneContext& context,
+                           const Array<MeasureCandidate>& measure_candidates,
+                           const Array<RunnerResult>& results) final {
+    ICHECK(this->state_ != nullptr);
+    this->state_->NotifyRunnerResults(results);
+  }
+};
+
+inline Optional<Array<MeasureCandidate>> ReplayFuncNode::State::GenerateMeasureCandidates() {
+  if (st >= self->num_trials_total) {
+    return NullOpt;
+  }
+  ed = std::min(ed, self->num_trials_total);
+  Array<MeasureCandidate> result;
+  for (int i = st; i < ed; i++) {
+    for (;;) {  // NOTE(review): no retry cap; loops forever if a postproc always fails
+      Array<tir::Schedule> schs = self->space_generator_->GenerateDesignSpace(self->mod_);
+      int design_space_index = tir::SampleInt(&self->rand_state_, 0, schs.size());
+      tir::Schedule sch = schs[design_space_index];
+      sch->EnterPostproc();
+      bool failed = false;
+      for (const Postproc& proc : self->postprocs_) {
+        if (!proc->Apply(sch)) {
+          failed = true;
+          break;
+        }
+      }
+      if (!failed) {
+        result.push_back(MeasureCandidate(sch, self->args_info_));
+        break;
+      }
+    }
+  }
+  return result;
+}
+
+inline void ReplayFuncNode::State::NotifyRunnerResults(const Array<RunnerResult>& results) {
+  st += self->num_trials_per_iter;
+  ed += self->num_trials_per_iter;
+}
+
+SearchStrategy SearchStrategy::ReplayFunc(int num_trials_per_iter, int num_trials_total) {
+  ObjectPtr<ReplayFuncNode> n = make_object<ReplayFuncNode>();
+  n->num_trials_per_iter = num_trials_per_iter;
+  n->num_trials_total = num_trials_total;
+  return SearchStrategy(n);
+}
+
+TVM_REGISTER_NODE_TYPE(ReplayFuncNode);
+TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyReplayFunc")
+    .set_body_typed(SearchStrategy::ReplayFunc);
+
+}  // namespace meta_schedule
+}  // namespace tvm
diff --git a/tests/python/unittest/test_meta_schedule_search_strategy.py b/tests/python/unittest/test_meta_schedule_search_strategy.py
index 668fca9..a4d3217 100644
--- a/tests/python/unittest/test_meta_schedule_search_strategy.py
+++ b/tests/python/unittest/test_meta_schedule_search_strategy.py
@@ -18,12 +18,11 @@
# pylint: disable=missing-function-docstring
import sys
import pytest
-from typing import List
-
import tvm
from tvm.meta_schedule import TuneContext
from tvm.meta_schedule.runner import RunnerResult
from tvm.meta_schedule.search_strategy import (
+ ReplayFunc,
ReplayTrace,
SearchStrategy,
)
@@ -75,17 +74,17 @@ def _schedule_matmul(sch: Schedule):
sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3)
[email protected]("TestClass", [ReplayTrace])
[email protected]("TestClass", [ReplayFunc, ReplayTrace])
def test_meta_schedule_replay_func(TestClass: SearchStrategy): # pylint:
disable = invalid-name
num_trials_per_iter = 7
num_trials_total = 20
strategy = TestClass(num_trials_per_iter=num_trials_per_iter,
num_trials_total=num_trials_total)
- context = TuneContext(mod=Matmul,
space_generator=ScheduleFn(sch_fn=_schedule_matmul))
- context.space_generator.initialize_with_tune_context(context)
- spaces = context.space_generator.generate_design_space(context.mod)
+ tune_context = TuneContext(mod=Matmul,
space_generator=ScheduleFn(sch_fn=_schedule_matmul))
+ tune_context.space_generator.initialize_with_tune_context(tune_context)
+ spaces =
tune_context.space_generator.generate_design_space(tune_context.mod)
- strategy.initialize_with_tune_context(context)
+ strategy.initialize_with_tune_context(tune_context)
strategy.pre_tuning(spaces)
(correct_sch,) =
ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul)
num_trials_each_iter: List[int] = []
@@ -100,7 +99,7 @@ def test_meta_schedule_replay_func(TestClass:
SearchStrategy): # pylint: disabl
remove_decisions=(isinstance(strategy, ReplayTrace)),
)
runner_results.append(RunnerResult(run_secs=[0.11, 0.41, 0.54],
error_msg=None))
- strategy.notify_runner_results(context, candidates, runner_results)
+ strategy.notify_runner_results(tune_context, candidates,
runner_results)
candidates = strategy.generate_measure_candidates()
strategy.post_tuning()
assert num_trials_each_iter == [7, 7, 6]