zxybazh commented on a change in pull request #10366: URL: https://github.com/apache/tvm/pull/10366#discussion_r816411413
########## File path: src/meta_schedule/task_scheduler/gradient_based.cc ########## @@ -0,0 +1,295 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +/*! \brief The gradient based task scheduler. */ +class GradientBasedNode final : public TaskSchedulerNode { + public: + int backward_window_size; + double alpha, beta; + + bool done_round_robin; // whether the warm up round robin has been done + int task_id = -1; // The current task id processed. 
+ support::LinearCongruentialEngine::TRandState rand_state; // the random state + std::vector<int> task_cnts; // task tuning counts + std::vector<double> task_weights; // task weights + std::vector<double> task_best_latencies; // best latency achieved by the task + std::vector<double> task_flop_counts; // total flop count of the task + std::vector<std::vector<double>> task_latency_history; // all task latency history + + std::vector<std::string> task_tag; // tag of the task for grouping + std::map<std::string, int> tag_to_group; // map to find the group id given task tag + std::vector<std::set<int>> task_groups; // the task ids in a given group + + TaskSchedulerNode::FObjectiveFunc objective_func; // the objective function + + void VisitAttrs(tvm::AttrVisitor* v) { + TaskSchedulerNode::VisitAttrs(v); + v->Visit("task_id", &task_id); + } + + static constexpr const char* _type_key = "meta_schedule.GradientBased"; + TVM_DECLARE_FINAL_OBJECT_INFO(GradientBasedNode, TaskSchedulerNode); + + protected: + double _compute_score(const std::vector<double>& latencies) { + Array<FloatImm> input_latencies; + for (double latency : latencies) + input_latencies.push_back(FloatImm(DataType::Float(32), latency)); + return objective_func(input_latencies)->value; + } + + void _adjust_similarity_group(int task_id) { + int group_id = tag_to_group[task_tag[task_id]]; + if (task_groups[group_id].size() <= 1 || + task_groups[group_id].find(task_id) == task_groups[group_id].end()) + return; + + double best_flops = -1.0; + int max_ct[3] = {-1, -1, -1}; // to find the 2nd largest + for (int i : task_groups[group_id]) { + best_flops = std::max(best_flops, task_flop_counts[i] / task_best_latencies[i]); + max_ct[0] = task_cnts[i]; + std::sort(max_ct, max_ct + 3); + } + double cur_flops = task_flop_counts[task_id] / task_best_latencies[task_id]; + // if we tune a task many times but it still cannot achieve + // a similar speed to the fastest one in its group, this means this task + // is 
actually not similar to other tasks in its group. + // So we will remove it from its original group. + if (cur_flops < best_flops / beta && task_cnts[task_id] > 5 + max_ct[1]) { + task_groups[group_id].erase(task_id); + } Review comment: Thanks for pointing it out. Originally, I checked a couple of lines above that the groups won't be changed if the group set does not contain the task. I have rewritten the logic to create a single-element group afterwards. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
