comaniac commented on a change in pull request #10366:
URL: https://github.com/apache/tvm/pull/10366#discussion_r816456183
##########
File path: src/meta_schedule/task_scheduler/gradient_based.cc
##########
@@ -24,66 +24,80 @@ namespace meta_schedule {
/*! \brief The gradient based task scheduler. */
class GradientBasedNode final : public TaskSchedulerNode {
public:
- int backward_window_size;
- double alpha, beta;
+ int backward_window_size; // the backward windows size for backward
gradient computation
+ double alpha, beta; // alpha, beta as in gradient computation
+ bool done_round_robin; // whether the round-robin warm up has been done
+
+ int task_id = -1; // the current task id processing
- bool done_round_robin; // whether the warm up round robin has been done
- int task_id = -1; // The current task id processed.
support::LinearCongruentialEngine::TRandState rand_state; // the random
state
std::vector<int> task_cnts; // task tuning
counts
std::vector<double> task_weights; // task weights
std::vector<double> task_best_latencies; // best latency
achieved by the task
std::vector<double> task_flop_counts; // total flop
count of the task
std::vector<std::vector<double>> task_latency_history; // all task
latency history
- std::vector<std::string> task_tag; // tag of the task for grouping
- std::map<std::string, int> tag_to_group; // map to find the group id given
task tag
- std::vector<std::set<int>> task_groups; // the task ids in a given group
+ std::vector<int> task_group_id; // group id of the task
+ std::vector<std::set<int>> task_groups; // the task ids in a given group
TaskSchedulerNode::FObjectiveFunc objective_func; // the objective function
void VisitAttrs(tvm::AttrVisitor* v) {
TaskSchedulerNode::VisitAttrs(v);
+ v->Visit("backward_window_size", &backward_window_size);
+ v->Visit("alpha", &alpha);
+ v->Visit("beta", &beta);
v->Visit("task_id", &task_id);
}
static constexpr const char* _type_key = "meta_schedule.GradientBased";
TVM_DECLARE_FINAL_OBJECT_INFO(GradientBasedNode, TaskSchedulerNode);
protected:
+ /*!
+ * \brief Compute the objective function score for given workload latencies.
+ * \param latencies The current best latencies of each workload.
+ * \return The computed objective function score.
+ */
double _compute_score(const std::vector<double>& latencies) {
Array<FloatImm> input_latencies;
for (double latency : latencies)
input_latencies.push_back(FloatImm(DataType::Float(32), latency));
return objective_func(input_latencies)->value;
}
+ /*!
+ * \brief Adjust the similarity group information for the given task.
+ * \param task_id The id of the task to adjust similarity group.
+ */
void _adjust_similarity_group(int task_id) {
- int group_id = tag_to_group[task_tag[task_id]];
- if (task_groups[group_id].size() <= 1 ||
- task_groups[group_id].find(task_id) == task_groups[group_id].end())
- return;
+ int group_id = task_group_id[task_id];
+ if (task_groups[group_id].size() <= 1) return;
double best_flops = -1.0;
- int max_ct[3] = {-1, -1, -1}; // to find the 2nd largest
+ int max_cnt[3] = {-1, -1, -1}; // to find the 2nd largest
for (int i : task_groups[group_id]) {
best_flops = std::max(best_flops, task_flop_counts[i] /
task_best_latencies[i]);
- max_ct[0] = task_cnts[i];
- std::sort(max_ct, max_ct + 3);
+ max_cnt[0] = task_cnts[i];
+ std::sort(max_cnt, max_cnt + 3); // place the 2nd largest to #1
}
double cur_flops = task_flop_counts[task_id] /
task_best_latencies[task_id];
// if we tune a task many times but it still cannot achieve
// a similar speed to the fastest one in its group, this means this task
// is actually not similar to other tasks in its group.
// So we will remove it from its original group.
Review comment:
Suggestion: So we will move it from the current group to a standalone
group.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]