comaniac commented on a change in pull request #10366:
URL: https://github.com/apache/tvm/pull/10366#discussion_r816456183



##########
File path: src/meta_schedule/task_scheduler/gradient_based.cc
##########
@@ -24,66 +24,80 @@ namespace meta_schedule {
 /*! \brief The gradient based task scheduler. */
 class GradientBasedNode final : public TaskSchedulerNode {
  public:
-  int backward_window_size;
-  double alpha, beta;
+  int backward_window_size;  // the backward window size for backward gradient computation
+  double alpha, beta;        // alpha, beta as in gradient computation
+  bool done_round_robin;     // whether the round-robin warm up has been done
+
+  int task_id = -1;  // the current task id processing
 
-  bool done_round_robin;  // whether the warm up round robin has been done
-  int task_id = -1;       // The current task id processed.
   support::LinearCongruentialEngine::TRandState rand_state;  // the random 
state
   std::vector<int> task_cnts;                                // task tuning 
counts
   std::vector<double> task_weights;                          // task weights
   std::vector<double> task_best_latencies;                   // best latency achieved by the task
   std::vector<double> task_flop_counts;                      // total flop 
count of the task
   std::vector<std::vector<double>> task_latency_history;     // all task 
latency history
 
-  std::vector<std::string> task_tag;        // tag of the task for grouping
-  std::map<std::string, int> tag_to_group;  // map to find the group id given 
task tag
-  std::vector<std::set<int>> task_groups;   // the task ids in a given group
+  std::vector<int> task_group_id;          // group id of the task
+  std::vector<std::set<int>> task_groups;  // the task ids in a given group
 
   TaskSchedulerNode::FObjectiveFunc objective_func;  // the objective function
 
   void VisitAttrs(tvm::AttrVisitor* v) {
     TaskSchedulerNode::VisitAttrs(v);
+    v->Visit("backward_window_size", &backward_window_size);
+    v->Visit("alpha", &alpha);
+    v->Visit("beta", &beta);
     v->Visit("task_id", &task_id);
   }
 
   static constexpr const char* _type_key = "meta_schedule.GradientBased";
   TVM_DECLARE_FINAL_OBJECT_INFO(GradientBasedNode, TaskSchedulerNode);
 
  protected:
+  /*!
+   * \brief Compute the objective function score for given workload latencies.
+   * \param latencies The current best latencies of each workload.
+   * \return The computed objective function score.
+   */
   double _compute_score(const std::vector<double>& latencies) {
     Array<FloatImm> input_latencies;
     for (double latency : latencies)
       input_latencies.push_back(FloatImm(DataType::Float(32), latency));
     return objective_func(input_latencies)->value;
   }
 
+  /*!
+   * \brief Adjust the similarity group information for the given task.
+   * \param task_id The id of the task to adjust similarity group.
+   */
   void _adjust_similarity_group(int task_id) {
-    int group_id = tag_to_group[task_tag[task_id]];
-    if (task_groups[group_id].size() <= 1 ||
-        task_groups[group_id].find(task_id) == task_groups[group_id].end())
-      return;
+    int group_id = task_group_id[task_id];
+    if (task_groups[group_id].size() <= 1) return;
 
     double best_flops = -1.0;
-    int max_ct[3] = {-1, -1, -1};  // to find the 2nd largest
+    int max_cnt[3] = {-1, -1, -1};  // to find the 2nd largest
     for (int i : task_groups[group_id]) {
       best_flops = std::max(best_flops, task_flop_counts[i] / 
task_best_latencies[i]);
-      max_ct[0] = task_cnts[i];
-      std::sort(max_ct, max_ct + 3);
+      max_cnt[0] = task_cnts[i];
+      std::sort(max_cnt, max_cnt + 3);  // place the 2nd largest to #1
     }
     double cur_flops = task_flop_counts[task_id] / 
task_best_latencies[task_id];
     // if we tune a task for many times but it still cannot achieve
     // a similar speed to the fastest one in its group, this means this task
     // is actually not similar to other tasks in its group.
     // So we will remove it from its original group.

Review comment:
       Suggestion: So we will move it from the current group to a standalone 
group.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to