This is an automated email from the ASF dual-hosted git repository.
xiyou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 96a513cd97 Patch replay trace. (#11621)
96a513cd97 is described below
commit 96a513cd97be4b42acb51d1c9b73288820e90185
Author: Xiyou Zhou <[email protected]>
AuthorDate: Wed Jun 8 11:39:42 2022 -0700
Patch replay trace. (#11621)
---
include/tvm/meta_schedule/search_strategy.h | 4 +++-
.../tvm/meta_schedule/search_strategy/replay_trace.py | 8 +++++++-
src/meta_schedule/search_strategy/replay_trace.cc | 18 +++++++++++++++---
3 files changed, 25 insertions(+), 5 deletions(-)
diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h
index baae22f0d9..5e249850f5 100644
--- a/include/tvm/meta_schedule/search_strategy.h
+++ b/include/tvm/meta_schedule/search_strategy.h
@@ -211,8 +211,10 @@ class SearchStrategy : public runtime::ObjectRef {
* \brief Constructor of replay trace search strategy.
* \param num_trials_per_iter The number of trials per iteration, i.e., the
batch size.
* \param max_trials_per_task The total number of trials for trace replaying.
+ * \param max_fail_count The max number of failures during trace replaying.
*/
- TVM_DLL static SearchStrategy ReplayTrace(int num_trials_per_iter, int max_trials_per_task);
+ TVM_DLL static SearchStrategy ReplayTrace(int num_trials_per_iter, int max_trials_per_task,
+ int max_fail_count);
/*!
* \brief Constructor of replay func search strategy.
diff --git a/python/tvm/meta_schedule/search_strategy/replay_trace.py b/python/tvm/meta_schedule/search_strategy/replay_trace.py
index 70461d65f7..36dbb8734e 100644
--- a/python/tvm/meta_schedule/search_strategy/replay_trace.py
+++ b/python/tvm/meta_schedule/search_strategy/replay_trace.py
@@ -33,15 +33,21 @@ class ReplayTrace(SearchStrategy):
Number of trials per iteration.
max_trials_per_task : int
Total number of trials for one task
+ max_fail_count : int
+ Max number of failures during trace replaying.
"""
num_trials_per_iter: int
max_trials_per_task: int
+ max_fail_count: int
- def __init__(self, num_trials_per_iter: int, max_trials_per_task: int):
+ def __init__(
+ self, num_trials_per_iter: int, max_trials_per_task: int, max_fail_count: int = 100
+ ):
"""Constructor"""
self.__init_handle_by_constructor__(
_ffi_api.SearchStrategyReplayTrace, # type: ignore # pylint: disable=no-member
num_trials_per_iter,
max_trials_per_task,
+ max_fail_count,
)
diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc
index 13f32a744e..355f71455d 100644
--- a/src/meta_schedule/search_strategy/replay_trace.cc
+++ b/src/meta_schedule/search_strategy/replay_trace.cc
@@ -60,6 +60,8 @@ class ReplayTraceNode : public SearchStrategyNode {
int num_trials_per_iter;
/*! \brief The number of total trials. */
int max_trials_per_task;
+ /*! \brief The max number of failures during trace replaying. */
+ int max_fail_count;
/*! \brief The tuning context of the search strategy. */
const TuneContextNode* context_{nullptr};
@@ -71,6 +73,7 @@ class ReplayTraceNode : public SearchStrategyNode {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("num_trials_per_iter", &num_trials_per_iter);
v->Visit("max_trials_per_task", &max_trials_per_task);
+ v->Visit("max_fail_count", &max_fail_count);
// `context_` is not visited.
// `rand_state_` is not visited
// `state_` is not visited
@@ -136,7 +139,8 @@ inline Optional<Array<MeasureCandidate>> ReplayTraceNode::State::GenerateMeasure
int task_id) -> void {
TRandState& rand_state = per_thread_rand_state[thread_id];
IRModule mod = this->per_thread_mod_[thread_id];
- for (;;) {
+
+ for (int fail_count = 0; fail_count < self->max_fail_count; fail_count++) {
int design_space_index = tir::SampleInt(&rand_state, 0, design_spaces.size());
tir::Trace trace = design_spaces[design_space_index];
tir::Trace new_trace = tir::Trace(trace->insts, {});
@@ -147,7 +151,13 @@ inline Optional<Array<MeasureCandidate>> ReplayTraceNode::State::GenerateMeasure
}
};
support::parallel_for_dynamic(0, ed - st, ctx->num_threads, f_worker);
- return per_task_result;
+ Array<MeasureCandidate> filtered;
+ filtered.reserve(ed - st);
+ for (MeasureCandidate result : per_task_result)
+ if (result.defined()) {
+ filtered.push_back(result);
+ }
+ return filtered;
}
inline void ReplayTraceNode::State::NotifyRunnerResults(const Array<RunnerResult>& results) {
@@ -155,10 +165,12 @@ inline void ReplayTraceNode::State::NotifyRunnerResults(const Array<RunnerResult
ed += self->num_trials_per_iter;
}
-SearchStrategy SearchStrategy::ReplayTrace(int num_trials_per_iter, int max_trials_per_task) {
+SearchStrategy SearchStrategy::ReplayTrace(int num_trials_per_iter, int max_trials_per_task,
+ int max_fail_count) {
ObjectPtr<ReplayTraceNode> n = make_object<ReplayTraceNode>();
n->num_trials_per_iter = num_trials_per_iter;
n->max_trials_per_task = max_trials_per_task;
+ n->max_fail_count = max_fail_count;
return SearchStrategy(n);
}