This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 715f24d238 [Metaschedule] Enable continuing tuning after schedule
application failure (#10937)
715f24d238 is described below
commit 715f24d2381d6dd9ce016f7214fe994a574fb358
Author: Masahiro Masuda <[email protected]>
AuthorDate: Fri Apr 15 16:29:36 2022 +0900
[Metaschedule] Enable continuing tuning after schedule application failure
(#10937)
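Currently, when schedule application fails during tuning
(e.g. tensorize), the entire tuning session is killed with an error message like
`RuntimeError: parallel_for_dynamic error with ...`. We should gracefully
handle such errors and let tuning continue on other candidates.
This change catches exceptions thrown while applying postprocessors in
`ThreadedTraceApply`, logs them as warnings, and treats the candidate as failed
(returns `NullOpt`). It also checks that every `MeasureCandidate` sent to the
builder is defined.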
No test is added, because there is currently no known way to make tuning fail
in a controlled, reproducible manner.
---
src/meta_schedule/task_scheduler/task_scheduler.cc | 1 +
src/meta_schedule/utils.h | 12 ++++++++++--
2 files changed, 11 insertions(+), 2 deletions(-)
diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc
b/src/meta_schedule/task_scheduler/task_scheduler.cc
index e30295fd1a..cd287fc1d4 100644
--- a/src/meta_schedule/task_scheduler/task_scheduler.cc
+++ b/src/meta_schedule/task_scheduler/task_scheduler.cc
@@ -34,6 +34,7 @@ void SendToBuilder(const Builder& builder, const TuneContext&
context) {
Array<BuilderInput> inputs;
inputs.reserve(candidates.size());
for (const MeasureCandidate& candidate : candidates) {
+ ICHECK(candidate.defined()) << "Undefined MeasureCandidate found";
inputs.push_back(BuilderInput(candidate->sch->mod(), target));
}
context->builder_results = builder->Build(inputs);
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index 45a04958ad..a29f991cbb 100644
--- a/src/meta_schedule/utils.h
+++ b/src/meta_schedule/utils.h
@@ -36,6 +36,7 @@
#include <tvm/meta_schedule/tune_context.h>
#include <tvm/node/node.h>
#include <tvm/node/serialization.h>
+#include <tvm/runtime/container/optional.h>
#include <tvm/support/parallel_for.h>
#include <tvm/tir/schedule/schedule.h>
@@ -308,12 +309,19 @@ struct ThreadedTraceApply {
/*rand_state=*/ForkSeed(rand_state),
/*debug_mode=*/0,
/*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone);
+
trace->ApplyToSchedule(sch, /*remove_postproc=*/true);
sch->EnterPostproc();
+
for (int i = 0; i < n_; ++i) {
Item& item = items_[i];
- if (!item.postproc->Apply(sch)) {
- ++item.fail_counter;
+ try {
+ if (!item.postproc->Apply(sch)) {
+ ++item.fail_counter;
+ return NullOpt;
+ }
+ } catch (const std::exception& e) {
+ LOG(WARNING) << "ThreadedTraceApply::Apply failed with error " <<
e.what();
return NullOpt;
}
}