zxybazh commented on code in PR #12895:
URL: https://github.com/apache/tvm/pull/12895#discussion_r984247929
##########
include/tvm/meta_schedule/search_strategy.h:
##########
@@ -257,8 +249,8 @@ class PySearchStrategyNode : public SearchStrategyNode {
}
void InitializeWithTuneContext(const TuneContext& context) final;
- void PreTuning(const Array<tir::Schedule>& design_spaces, const Optional<Database>& database,
- const Optional<CostModel>& cost_model) final;
+ void PreTuning(int max_trials, int num_trials_per_iter, const Array<tir::Schedule>& design_spaces,
Review Comment:
May I ask how `max_trials_per_task` is used in the new interface?
##########
include/tvm/meta_schedule/task_scheduler.h:
##########
@@ -143,55 +202,43 @@ class TaskScheduler;
/*! \brief The task scheduler with customized methods on the python-side. */
class PyTaskSchedulerNode : public TaskSchedulerNode {
public:
- /*! \brief The function type of `Tune` method. */
- using FTune = runtime::TypedPackedFunc<void()>;
-
- /*! \brief The function type of `InitializeTask` method. */
- using FInitializeTask = runtime::TypedPackedFunc<void(int)>;
-
/*!
- * \brief The function type of `TouchTask` method.
- * \param task_id The task id to be checked.
- * \return Whether the task is running.
+ * \brief The function type of `NextTaskId` method.
+ * \return The next task id.
*/
- using FTouchTask = runtime::TypedPackedFunc<void(int)>;
-
+ using FNextTaskId = runtime::TypedPackedFunc<int()>;
/*!
* \brief The function type of `JoinRunningTask` method.
* \param task_id The task id to be joined.
*/
using FJoinRunningTask = runtime::TypedPackedFunc<Array<RunnerResult>(int)>;
+ /*! \brief The function type of `Tune` method. */
+ using FTune = runtime::TypedPackedFunc<void(
Review Comment:
Ditto: please fix the formatting here as well.
##########
python/tvm/meta_schedule/cost_model/cost_model.py:
##########
@@ -16,23 +16,30 @@
# under the License.
"""Meta Schedule CostModel."""
import ctypes
-from typing import Callable, List
+from typing import Callable, List, Union
+
+# isort: off
+from typing_extensions import Literal
+
+# isort: on
import numpy as np # type: ignore
from tvm._ffi import register_object
-from tvm.meta_schedule.utils import _get_default_str
from tvm.runtime import Object
from .. import _ffi_api
from ..runner import RunnerResult
from ..search_strategy import MeasureCandidate
from ..tune_context import TuneContext
+from ..utils import _get_default_str
@register_object("meta_schedule.CostModel")
class CostModel(Object):
"""Cost model."""
+ CostModelType = Union["CostModel", Literal["xgb", "mlp", "random"]]
Review Comment:
Maybe we need a note or comment somewhere reminding developers to update
this list whenever a new cost model is added.
##########
python/tvm/meta_schedule/relay_integration.py:
##########
@@ -15,28 +15,82 @@
# specific language governing permissions and limitations
# under the License.
"""MetaSchedule-Relay integration"""
-from typing import Any, Dict, List, Optional
+from contextlib import contextmanager
+from typing import Dict, List, Optional, Tuple, Union
+# isort: off
Review Comment:
If this is the only import we don't want sorted, the isort directive could
be placed inline after that import instead. By the way, why is isort
disabled for this import?
##########
include/tvm/support/random_engine.h:
##########
@@ -46,32 +43,18 @@ namespace support {
class LinearCongruentialEngine {
public:
- /*!
- * \brief The result type is defined as uint64_t here to avoid overflow.
- * \note The type name is not in Google style because it is used in STL's distribution inferface.
Review Comment:
I suggest keeping this comment; it helps readers understand the code.
##########
include/tvm/meta_schedule/task_scheduler.h:
##########
@@ -73,66 +127,71 @@ namespace meta_schedule {
*/
class TaskSchedulerNode : public runtime::Object {
public:
- /*! \brief The tasks to be tuned */
- Array<TuneContext> tasks;
- /*! \brief The builder of the scheduler. */
- Builder builder{nullptr};
- /*! \brief The runner of the scheduler. */
- Runner runner{nullptr};
- /*! \brief The database of the scheduler. */
- Optional<Database> database;
- /*! \brief The cost model of the scheduler. */
- Optional<CostModel> cost_model;
+ /*! \brief The tuning task's logging function. */
+ PackedFunc logger;
+ /*! \brief Records for each task */
+ Array<TaskRecord> tasks_;
/*! \brief The list of measure callbacks of the scheduler. */
- Array<MeasureCallback> measure_callbacks;
- /*! \brief The maximum number of trials allowed. */
- int max_trials;
- /*! \brief The number of trials already conducted. */
- int num_trials_already;
- /*! \brief The tuning task's logging function. t*/
- PackedFunc logging_func;
+ Array<MeasureCallback> measure_callbacks_;
+ /*! \brief The database used in tuning */
+ Optional<Database> database_;
+ /*! \brief The cost model used in tuning */
+ Optional<CostModel> cost_model_;
+ /*! \brief The number of remaining tasks to be tuned. */
+ int remaining_tasks_;
/*! \brief The default destructor. */
virtual ~TaskSchedulerNode() = default;
void VisitAttrs(tvm::AttrVisitor* v) {
- v->Visit("tasks", &tasks);
- v->Visit("builder", &builder);
- v->Visit("runner", &runner);
- v->Visit("database", &database);
- v->Visit("cost_model", &cost_model);
- v->Visit("measure_callbacks", &measure_callbacks);
- v->Visit("max_trials", &max_trials);
- v->Visit("num_trials_already", &num_trials_already);
- // `logging_func` is not visited
+ // `logger` is not visited
+ v->Visit("tasks_", &tasks_);
+ v->Visit("measure_callbacks_", &measure_callbacks_);
+ v->Visit("database_", &database_);
+ v->Visit("cost_model_", &cost_model_);
+ v->Visit("remaining_tasks_", &remaining_tasks_);
}
- /*! \brief Auto-tuning. */
- virtual void Tune();
-
- /*!
- * \brief Initialize modules of the given task.
- * \param task_id The task id to be initialized.
- */
- virtual void InitializeTask(int task_id);
-
/*!
- * \brief Touch the task and update its status
- * \param task_id The task id to be checked.
+ * \brief Fetch the next task id.
+ * \return The next task id.
*/
- virtual void TouchTask(int task_id);
-
+ virtual int NextTaskId() = 0;
/*!
* \brief Wait until the task is finished.
* \param task_id The task id to be joined.
+ * \return The results from the runner.
*/
virtual Array<RunnerResult> JoinRunningTask(int task_id);
-
/*!
- * \brief Fetch the next task id.
- * \return The next task id.
+ * \brief Jointly tune a given list of tasks.
+ * \param tasks The tasks to be tuned
+ * \param task_weights The weight of each task
+ * \param max_trials_global The maximum number of trials to be performed globally
+ * \param max_trials_per_task The maximum number of trials to be performed for each task
+ * \param num_trials_per_iter The number of trials to be performed in each iteration
+ * \param builder The MetaSchedule builder
+ * \param runner The MetaSchedule runner
+ * \param measure_callbacks The callbacks to be called after each measurement
+ * \param database The database used in tuning
+ * \param cost_model The cost model used in tuning
*/
- virtual int NextTaskId() = 0;
+ virtual void Tune(Array<TuneContext> tasks, Array<FloatImm> task_weights, int max_trials_global,
Review Comment:
Can you line these arguments up for better code formatting?
##########
python/tvm/meta_schedule/logging.py:
##########
@@ -0,0 +1,262 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Logging interface in MetaSchedule"""
+import logging
+import logging.config
+import os
+import os.path as osp
+from logging import Logger
+from typing import Any, Callable, Dict, List, Optional
+
+
+def get_logger(name: str) -> Logger:
+ """Create or get a logger by its name. This is essentially a wrapper of python's native logger.
+
+ Parameters
+ ----------
+ name : str
+ The name of the logger.
+
+ Returns
+ -------
+ logger : Logger
+ The logger instance.
+ """
+ return logging.getLogger(name)
+
+
+def make_logging_func(logger: Logger) -> Optional[Callable[[int, str], None]]:
Review Comment:
`get_logging_func` would be a more fitting name.
##########
include/tvm/meta_schedule/space_generator.h:
##########
@@ -171,6 +213,7 @@ class PySpaceGeneratorNode : public SpaceGeneratorNode {
FClone f_clone;
void VisitAttrs(tvm::AttrVisitor* v) {
+ SpaceGeneratorNode::VisitAttrs(v);
Review Comment:
Is this expected for all derived classes?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]