comaniac commented on a change in pull request #10366: URL: https://github.com/apache/tvm/pull/10366#discussion_r815189229
########## File path: python/tvm/meta_schedule/task_scheduler/gradient_based.py ########## @@ -0,0 +1,146 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Gradient Based Task Scheduler""" +import math + +from typing import TYPE_CHECKING, List, Optional +from tvm._ffi import register_object +from tvm._ffi.registry import register_func + +from tvm.ir import IRModule +from ..measure_callback import MeasureCallback +from ..builder import Builder +from ..runner import Runner +from ..database import Database +from ..cost_model import CostModel +from .task_scheduler import TaskScheduler + +from .. import _ffi_api + +if TYPE_CHECKING: + from ..tune_context import TuneContext + + +@register_func("meta_schedule.task_scheduler.derive_similarity_tag") +def derive_similarity_tag(mod: IRModule, log_base: float = 1.618) -> str: + """Get the tags for smilarity group creation + + Parameters + --------- + mod : IRModule + The input workload. + log_base : float + The log base to normalize the flop count. Review comment: Out of curiosity, why does the log base matter? It seems to me that you only use it for the FLOP count (to avoid overflow)? ```suggestion The log base to normalize the flop count. Default: 1.618. 
``` ########## File path: python/tvm/meta_schedule/task_scheduler/gradient_based.py ########## @@ -0,0 +1,146 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Gradient Based Task Scheduler""" +import math + +from typing import TYPE_CHECKING, List, Optional +from tvm._ffi import register_object +from tvm._ffi.registry import register_func + +from tvm.ir import IRModule +from ..measure_callback import MeasureCallback +from ..builder import Builder +from ..runner import Runner +from ..database import Database +from ..cost_model import CostModel +from .task_scheduler import TaskScheduler + +from .. import _ffi_api + +if TYPE_CHECKING: + from ..tune_context import TuneContext + + +@register_func("meta_schedule.task_scheduler.derive_similarity_tag") +def derive_similarity_tag(mod: IRModule, log_base: float = 1.618) -> str: + """Get the tags for smilarity group creation + + Parameters + --------- Review comment: ```suggestion ---------- ``` ########## File path: python/tvm/meta_schedule/task_scheduler/gradient_based.py ########## @@ -0,0 +1,146 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Gradient Based Task Scheduler""" +import math + +from typing import TYPE_CHECKING, List, Optional +from tvm._ffi import register_object +from tvm._ffi.registry import register_func + +from tvm.ir import IRModule +from ..measure_callback import MeasureCallback +from ..builder import Builder +from ..runner import Runner +from ..database import Database +from ..cost_model import CostModel +from .task_scheduler import TaskScheduler + +from .. import _ffi_api + +if TYPE_CHECKING: + from ..tune_context import TuneContext + + +@register_func("meta_schedule.task_scheduler.derive_similarity_tag") +def derive_similarity_tag(mod: IRModule, log_base: float = 1.618) -> str: + """Get the tags for smilarity group creation + + Parameters + --------- + mod : IRModule + The input workload. + log_base : float + The log base to normalize the flop count. + + Return + ------ + tag : str + The generated similarity tag. + """ + ret = "" + for var in mod.get_global_vars(): + Review comment: remove this empty line. ########## File path: src/meta_schedule/task_scheduler/gradient_based.cc ########## @@ -0,0 +1,295 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +/*! \brief The gradient based task scheduler. */ +class GradientBasedNode final : public TaskSchedulerNode { + public: + int backward_window_size; + double alpha, beta; + + bool done_round_robin; // whether the warm up round robin has been done + int task_id = -1; // The current task id processed. 
+ support::LinearCongruentialEngine::TRandState rand_state; // the random state + std::vector<int> task_cnts; // task tuning counts + std::vector<double> task_weights; // task weights + std::vector<double> task_best_latencies; // best latency achived by the task + std::vector<double> task_flop_counts; // total flop count of the task + std::vector<std::vector<double>> task_latency_history; // all task latency history + + std::vector<std::string> task_tag; // tag of the task for grouping + std::map<std::string, int> tag_to_group; // map to find the group id given task tag + std::vector<std::set<int>> task_groups; // the task ids in a given group + + TaskSchedulerNode::FObjectiveFunc objective_func; // the objective function + + void VisitAttrs(tvm::AttrVisitor* v) { + TaskSchedulerNode::VisitAttrs(v); + v->Visit("task_id", &task_id); + } + + static constexpr const char* _type_key = "meta_schedule.GradientBased"; + TVM_DECLARE_FINAL_OBJECT_INFO(GradientBasedNode, TaskSchedulerNode); + + protected: + double _compute_score(const std::vector<double>& latencies) { + Array<FloatImm> input_latencies; + for (double latency : latencies) + input_latencies.push_back(FloatImm(DataType::Float(32), latency)); + return objective_func(input_latencies)->value; + } + + void _adjust_similarity_group(int task_id) { + int group_id = tag_to_group[task_tag[task_id]]; + if (task_groups[group_id].size() <= 1 || + task_groups[group_id].find(task_id) == task_groups[group_id].end()) + return; + + double best_flops = -1.0; + int max_ct[3] = {-1, -1, -1}; // to find the 2nd largest + for (int i : task_groups[group_id]) { + best_flops = std::max(best_flops, task_flop_counts[i] / task_best_latencies[i]); + max_ct[0] = task_cnts[i]; + std::sort(max_ct, max_ct + 3); + } + double cur_flops = task_flop_counts[task_id] / task_best_latencies[task_id]; + // if we tune a task for many times but it still cannot achieve + // a similar speed to the fastest one in its group, this means this task + // is 
actually not similar to other tasks in its group. + // So we will remove it from its original group. + if (cur_flops < best_flops / beta && task_cnts[task_id] > 5 + max_ct[1]) { + task_groups[group_id].erase(task_id); + } Review comment: What happens afterward? Will this task join another group? Also should `tag_to_group` be updated as well? ########## File path: python/tvm/meta_schedule/task_scheduler/gradient_based.py ########## @@ -0,0 +1,146 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Gradient Based Task Scheduler""" +import math + +from typing import TYPE_CHECKING, List, Optional +from tvm._ffi import register_object +from tvm._ffi.registry import register_func + +from tvm.ir import IRModule +from ..measure_callback import MeasureCallback +from ..builder import Builder +from ..runner import Runner +from ..database import Database +from ..cost_model import CostModel +from .task_scheduler import TaskScheduler + +from .. 
import _ffi_api + +if TYPE_CHECKING: + from ..tune_context import TuneContext + + +@register_func("meta_schedule.task_scheduler.derive_similarity_tag") +def derive_similarity_tag(mod: IRModule, log_base: float = 1.618) -> str: + """Get the tags for smilarity group creation + + Parameters + --------- + mod : IRModule + The input workload. + log_base : float + The log base to normalize the flop count. + + Return + ------ + tag : str + The generated similarity tag. + """ + ret = "" + for var in mod.get_global_vars(): + + if "meta_scheduler_task_scheduler_tag" in mod[var].attrs: + ret += mod[var].attrs.meta_scheduler_task_scheduler_tag + "_" + if ret: + flop_count = _ffi_api.TaskSchedulerFlopCount(mod) # type: ignore # pylint: disable=no-member + ret += "%d" % int(math.log(flop_count + 1, log_base)) + return ret + + +@register_object("meta_schedule.GradientBased") +class GradientBased(TaskScheduler): + """Gradient Based Task Scheduler""" + + def __init__( + self, + tasks: List["TuneContext"], + builder: Builder, + runner: Runner, + database: Database, + *, + alpha: float = 0.2, + beta: float = 2.0, + backward_window_size: int = 3, + seed: int = -1, + task_weights: List[float] = None, + objective_func_name: str = "meta_schedule.task_scheduler.objective_func", + tag_generation_func_name: str = "meta_schedule.task_scheduler.derive_similarity_tag", + cost_model: Optional[CostModel] = None, + measure_callbacks: Optional[List[MeasureCallback]] = None, + ) -> None: + """Constructor. + + Parameters + ---------- + tasks : List[TuneContext] + List of tasks to schedule. + builder : Builder + The builder. + runner : Runner + The runner. + database : Database + The database. + alpha: float + The parameter alpha to control gradient computation. + beta: float + The parameter beta to control gradient computation. + backward_window_size: int + The parameter to control backward window size. + seed: int + The random seed. Review comment: Please specify the default values. 
-- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
