junrushao commented on code in PR #12895:
URL: https://github.com/apache/tvm/pull/12895#discussion_r985191082
##########
src/meta_schedule/task_scheduler/task_scheduler.cc:
##########
@@ -21,83 +21,225 @@
namespace tvm {
namespace meta_schedule {
-void TaskSchedulerNode::InitializeTask(int task_id) {
+TaskRecord::TaskRecord(TuneContext ctx, double task_weight) {
+ ObjectPtr<TaskRecordNode> n = runtime::make_object<TaskRecordNode>();
+ n->ctx = ctx;
+ n->task_weight = task_weight;
+ n->flop = 1.0;
auto _ = Profiler::TimedScope("InitializeTask");
- TuneContext task = this->tasks[task_id];
- TVM_PY_LOG(INFO, this->logging_func)
- << "Initializing Task #" << task_id << ": " << task->task_name;
- TVM_PY_LOG(INFO, task->logging_func)
- << "Initializing Task #" << task_id << ": " << task->task_name;
- CHECK(task->mod.defined()) << "ValueError: Require `context.mod`, but it is
not defined";
- CHECK(task->space_generator.defined())
+ CHECK(ctx->mod.defined()) << "ValueError: Require `context.mod`, but it is
not defined";
+ CHECK(ctx->space_generator.defined())
<< "ValueError: Require `context.space_generator`, but it is not
defined";
- CHECK(task->search_strategy.defined())
+ CHECK(ctx->search_strategy.defined())
<< "ValueError: Require `context.search_strategy`, but it is not
defined";
- TVM_PY_LOG(INFO, task->logging_func) << "\n" << tir::AsTVMScript(task->mod);
- task->Initialize();
- Array<tir::Schedule> design_spaces =
- task->space_generator.value()->GenerateDesignSpace(task->mod.value());
- TVM_PY_LOG(INFO, task->logging_func)
- << "Total " << design_spaces.size() << " design space(s) generated";
- for (int i = 0, n = design_spaces.size(); i < n; ++i) {
- tir::Schedule sch = design_spaces[i];
- tir::Trace trace = sch->trace().value();
- trace = trace->Simplified(true);
- TVM_PY_LOG(INFO, task->logging_func) << "Design space #" << i << ":\n"
- << tir::AsTVMScript(sch->mod()) <<
"\n"
- << Concat(trace->AsPython(false),
"\n");
+ TVM_PY_LOG(INFO, ctx->logger) << "\n" << tir::AsTVMScript(ctx->mod);
+ ctx->Initialize();
+ n->flop = std::max(1.0, tir::EstimateTIRFlops(ctx->mod.value()));
+ this->data_ = std::move(n);
+}
+
+void SendToBuilder(TaskRecordNode* self, const Builder& builder) {
+ auto _ = Profiler::TimedScope("SendToBuilder");
+ Array<MeasureCandidate> candidates = self->measure_candidates.value();
+ Target target = self->ctx->target.value();
+ Array<BuilderInput> inputs;
+ inputs.reserve(candidates.size());
+ for (const MeasureCandidate& candidate : candidates) {
+ inputs.push_back(BuilderInput(candidate->sch->mod(), target));
}
- task->search_strategy.value()->PreTuning(design_spaces, database,
cost_model);
+ self->builder_results = builder->Build(inputs);
}
-void TaskSchedulerNode::Tune() {
- int n_tasks = this->tasks.size();
- for (int task_id = 0; task_id < n_tasks; ++task_id) {
- InitializeTask(task_id);
+void SendToRunner(TaskRecordNode* self, const Runner& runner) {
+ auto _ = Profiler::TimedScope("SendToRunner");
+ Array<MeasureCandidate> candidates = self->measure_candidates.value();
+ Array<BuilderResult> builder_results = self->builder_results.value();
+ Target target = self->ctx->target.value();
+ ICHECK_EQ(candidates.size(), builder_results.size());
+ int n = candidates.size();
+ int n_build_errors = 0;
+ Array<RunnerInput> inputs;
+ inputs.reserve(n);
+ for (int i = 0; i < n; ++i) {
+ const MeasureCandidate& candidate = candidates[i];
+ const BuilderResult& builder_result = builder_results[i];
+ if (builder_result->error_msg.defined()) {
+ ++n_build_errors;
+ continue;
+ }
+
inputs.push_back(RunnerInput(/*artifact_path=*/builder_result->artifact_path.value(),
+ /*device_type=*/target->kind->name,
+ /*args_info=*/candidate->args_info));
+ }
+ Array<RunnerFuture> futures = runner->Run(inputs);
+ if (n_build_errors == 0) {
+ self->runner_futures = futures;
+ return;
+ }
+ Array<RunnerFuture> results;
+ results.reserve(n);
+ for (int i = 0, j = 0; i < n; ++i) {
+ const BuilderResult& builder_result = builder_results[i];
+ if (builder_result->error_msg.defined()) {
+ results.push_back(RunnerFuture(
+ /*f_done=*/[]() -> bool { return true; },
+ /*f_result=*/
+ [msg = builder_result->error_msg]() -> RunnerResult {
+ return RunnerResult(NullOpt, msg);
+ }));
+ } else {
+ results.push_back(futures[j++]);
+ }
+ }
+ self->runner_futures = results;
+}
+
+void TaskCleanUp(TaskRecordNode* self, int task_id, const Array<RunnerResult>&
results) {
+ ICHECK_EQ(self->builder_results.value().size(), results.size());
+ ICHECK_EQ(self->runner_futures.value().size(), results.size());
+ int n = results.size();
+ std::string name = self->ctx->task_name.value();
+ const PackedFunc& logger = self->ctx->logger;
+ for (int i = 0; i < n; ++i) {
+ const BuilderResult& builder_result = self->builder_results.value()[i];
+ const MeasureCandidate& candidate = self->measure_candidates.value()[i];
+ const RunnerResult& runner_result = results[i];
+ Optional<String> error_msg = NullOpt;
+ int trials = self->latency_ms.size() + 1;
+ double run_ms = 1e9;
+ if ((error_msg = builder_result->error_msg)) {
+ ++self->build_error_count;
+ } else if ((error_msg = runner_result->error_msg)) {
+ ++self->run_error_count;
+ } else {
+ run_ms = GetRunMsMedian(runner_result);
+ }
+ self->latency_ms.push_back(run_ms);
+ if (error_msg) {
Review Comment:
It could be pretty frequent if something goes wrong, but normally it doesn't
happen often. By default, the error message is only sent to a per-task logging
file rather than the main screen, so I wouldn't worry too much about it
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]