This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 71f32ca4e8 [MetaSchedule][UX] Support Interactive Performance Table
Printing in Notebook (#13006)
71f32ca4e8 is described below
commit 71f32ca4e8e6f33da55cd6c39c5019caadcdc78e
Author: Xiyou Zhou <[email protected]>
AuthorDate: Fri Oct 14 20:24:50 2022 -0700
[MetaSchedule][UX] Support Interactive Performance Table Printing in
Notebook (#13006)
* Support interactive table printing.
* Rebase.
* Fix jupyter outputs.
* Fix CI.
* fix CI.
* Change file to filename.
* Address issues.
---
include/tvm/meta_schedule/task_scheduler.h | 4 +--
python/tvm/meta_schedule/logging.py | 24 +++++++++------
.../meta_schedule/task_scheduler/task_scheduler.py | 12 ++------
python/tvm/meta_schedule/utils.py | 34 +++++++++++++++++++-
src/meta_schedule/task_scheduler/gradient_based.cc | 3 +-
src/meta_schedule/task_scheduler/task_scheduler.cc | 29 +++++++++--------
src/meta_schedule/utils.h | 36 +++++++++++++---------
7 files changed, 93 insertions(+), 49 deletions(-)
diff --git a/include/tvm/meta_schedule/task_scheduler.h
b/include/tvm/meta_schedule/task_scheduler.h
index 17d82558fb..f4fc491286 100644
--- a/include/tvm/meta_schedule/task_scheduler.h
+++ b/include/tvm/meta_schedule/task_scheduler.h
@@ -196,8 +196,8 @@ class TaskSchedulerNode : public runtime::Object {
* \param task_id The task id to be checked.
*/
void TouchTask(int task_id);
- /*! \brief Returns a human-readable string of the tuning statistics. */
- std::string TuningStatistics() const;
+ /*! \brief Print out a human-readable format of the tuning statistics. */
+ void PrintTuningStatistics();
static constexpr const char* _type_key = "meta_schedule.TaskScheduler";
TVM_DECLARE_BASE_OBJECT_INFO(TaskSchedulerNode, Object);
diff --git a/python/tvm/meta_schedule/logging.py
b/python/tvm/meta_schedule/logging.py
index 9d673266a3..53353e3aa9 100644
--- a/python/tvm/meta_schedule/logging.py
+++ b/python/tvm/meta_schedule/logging.py
@@ -39,7 +39,7 @@ def get_logger(name: str) -> Logger:
return logging.getLogger(name)
-def get_logging_func(logger: Logger) -> Optional[Callable[[int, str], None]]:
+def get_logging_func(logger: Logger) -> Optional[Callable[[int, str, int,
str], None]]:
"""Get the logging function.
Parameters
@@ -62,15 +62,15 @@ def get_logging_func(logger: Logger) ->
Optional[Callable[[int, str], None]]:
# logging.FATAL not included
}
- def logging_func(level: int, msg: str):
- if level < 0:
+ def logging_func(level: int, filename: str, lineo: int, msg: str):
+ if level < 0: # clear the output in notebook / console
from IPython.display import ( # type: ignore # pylint:
disable=import-outside-toplevel
clear_output,
)
clear_output(wait=True)
else:
- level2log[level](msg)
+ level2log[level](f"[{os.path.basename(filename)}:{lineo}] " + msg)
return logging_func
@@ -94,12 +94,15 @@ def create_loggers(
global_logger_name = "tvm.meta_schedule"
global_logger = logging.getLogger(global_logger_name)
if global_logger.level is logging.NOTSET:
- global_logger.setLevel(logging.INFO)
+ global_logger.setLevel(logging.DEBUG)
+ console_logging_level = logging._levelToName[ # pylint:
disable=protected-access
+ global_logger.level
+ ]
config["loggers"].setdefault(
global_logger_name,
{
- "level": logging._levelToName[global_logger.level], # pylint:
disable=protected-access
+ "level": logging.DEBUG,
"handlers": [handler.get_name() for handler in
global_logger.handlers]
+ [global_logger_name + ".console", global_logger_name + ".file"],
"propagate": False,
@@ -108,7 +111,7 @@ def create_loggers(
config["loggers"].setdefault(
"{logger_name}",
{
- "level": "INFO",
+ "level": "DEBUG",
"handlers": [
"{logger_name}.file",
],
@@ -121,6 +124,7 @@ def create_loggers(
"class": "logging.StreamHandler",
"stream": "ext://sys.stdout",
"formatter": "tvm.meta_schedule.standard_formatter",
+ "level": console_logging_level,
},
)
config["handlers"].setdefault(
@@ -129,7 +133,7 @@ def create_loggers(
"class": "logging.FileHandler",
"filename": "{log_dir}/" + __name__ + ".task_scheduler.log",
"mode": "a",
- "level": "INFO",
+ "level": "DEBUG",
"formatter": "tvm.meta_schedule.standard_formatter",
},
)
@@ -139,14 +143,14 @@ def create_loggers(
"class": "logging.FileHandler",
"filename": "{log_dir}/{logger_name}.log",
"mode": "a",
- "level": "INFO",
+ "level": "DEBUG",
"formatter": "tvm.meta_schedule.standard_formatter",
},
)
config["formatters"].setdefault(
"tvm.meta_schedule.standard_formatter",
{
- "format": "%(asctime)s.%(msecs)03d %(levelname)s %(message)s",
+ "format": "%(asctime)s [%(levelname)s] %(message)s",
"datefmt": "%Y-%m-%d %H:%M:%S",
},
)
diff --git a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py
b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py
index f06f4d911f..d56d944474 100644
--- a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py
+++ b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py
@@ -163,15 +163,9 @@ class TaskScheduler(Object):
"""
_ffi_api.TaskSchedulerTouchTask(self, task_id) # type: ignore #
pylint: disable=no-member
- def tuning_statistics(self) -> str:
- """Returns a human-readable string of the tuning statistics.
-
- Returns
- -------
- tuning_statistics : str
- The tuning statistics.
- """
- return _ffi_api.TaskSchedulerTuningStatistics(self) # type: ignore #
pylint: disable=no-member
+ def print_tuning_statistics(self) -> None:
+ """Print out a human-readable format of the tuning statistics."""
+ return _ffi_api.TaskSchedulerPrintTuningStatistics(self) # type:
ignore # pylint: disable=no-member
@staticmethod
def create( # pylint: disable=keyword-arg-before-vararg
diff --git a/python/tvm/meta_schedule/utils.py
b/python/tvm/meta_schedule/utils.py
index eb3c643760..401fdab08a 100644
--- a/python/tvm/meta_schedule/utils.py
+++ b/python/tvm/meta_schedule/utils.py
@@ -188,13 +188,45 @@ def cpu_count(logical: bool = True) -> int:
@register_func("meta_schedule.using_ipython")
-def _using_ipython():
+def _using_ipython() -> bool:
+ """Return whether the current process is running in an IPython shell.
+
+ Returns
+ -------
+ result : bool
+ Whether the current process is running in an IPython shell.
+ """
try:
return get_ipython().__class__.__name__ == "ZMQInteractiveShell" #
type: ignore
except NameError:
return False
+@register_func("meta_schedule.print_interactive_table")
+def print_interactive_table(data: str) -> None:
+ """Print the performance table as an interactive dataframe in a notebook.
+
+ Parameters
+ ----------
+ data : str
+ The serialized performance table from MetaSchedule table printer.
+ """
+ import pandas as pd # type: ignore # pylint:
disable=import-outside-toplevel
+ from IPython.display import display # type: ignore # pylint:
disable=import-outside-toplevel
+
+ pd.set_option("display.max_rows", None)
+ pd.set_option("display.max_colwidth", None)
+ parsed = [
+ x.split("|")[1:] for x in list(filter(lambda x: set(x) != {"-"},
data.strip().split("\n")))
+ ]
+ display(
+ pd.DataFrame(
+ parsed[1:],
+ columns=parsed[0],
+ )
+ )
+
+
def get_global_func_with_default_on_worker(
name: Union[None, str, Callable],
default: Callable,
diff --git a/src/meta_schedule/task_scheduler/gradient_based.cc
b/src/meta_schedule/task_scheduler/gradient_based.cc
index bae52573a0..e0470337b5 100644
--- a/src/meta_schedule/task_scheduler/gradient_based.cc
+++ b/src/meta_schedule/task_scheduler/gradient_based.cc
@@ -60,7 +60,8 @@ class GradientBasedNode final : public TaskSchedulerNode {
int n_tasks = this->tasks_.size();
// Step 1. Check if it's in round robin mode.
if (round_robin_rounds_ == 0) {
- TVM_PY_LOG(INFO, this->logger) << "\n" << this->TuningStatistics();
+ TVM_PY_LOG_CLEAR_SCREEN(this->logger);
+ this->PrintTuningStatistics();
}
if (round_robin_rounds_ < n_tasks) {
return round_robin_rounds_++;
diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc
b/src/meta_schedule/task_scheduler/task_scheduler.cc
index 21efde26d9..69a70f63c5 100644
--- a/src/meta_schedule/task_scheduler/task_scheduler.cc
+++ b/src/meta_schedule/task_scheduler/task_scheduler.cc
@@ -232,9 +232,8 @@ Array<RunnerResult> TaskSchedulerNode::JoinRunningTask(int
task_id) {
}
TaskCleanUp(task, task_id, results);
TVM_PY_LOG_CLEAR_SCREEN(this->logger);
- TVM_PY_LOG(INFO, this->logger) << "[Updated] Task #" << task_id << ": " <<
task->ctx->task_name
- << "\n"
- << this->TuningStatistics();
+ TVM_PY_LOG(INFO, this->logger) << "[Updated] Task #" << task_id << ": " <<
task->ctx->task_name;
+ this->PrintTuningStatistics();
return results;
}
@@ -257,12 +256,11 @@ void TaskSchedulerNode::TerminateTask(int task_id) {
--this->remaining_tasks_;
TVM_PY_LOG_CLEAR_SCREEN(this->logger);
TVM_PY_LOG(INFO, this->logger) << "Task #" << task_id
- << " has finished. Remaining task(s): " <<
this->remaining_tasks_
- << "\n"
- << this->TuningStatistics();
+ << " has finished. Remaining task(s): " <<
this->remaining_tasks_;
+ this->PrintTuningStatistics();
}
-std::string TaskSchedulerNode::TuningStatistics() const {
+void TaskSchedulerNode::PrintTuningStatistics() {
std::ostringstream os;
int n_tasks = this->tasks_.size();
int total_trials = 0;
@@ -307,11 +305,18 @@ std::string TaskSchedulerNode::TuningStatistics() const {
}
}
p.Separator();
- os << p.AsStr() //
- << "\nTotal trials: " << total_trials //
+
+ os << "\nTotal trials: " << total_trials //
<< "\nTotal latency (us): " << total_latency //
<< "\n";
- return os.str();
+
+ if (using_ipython()) {
+ print_interactive_table(p.AsStr());
+ std::cout << os.str() << std::endl << std::flush;
+ TVM_PY_LOG(DEBUG, this->logger) << "\n" << p.AsStr() << os.str();
+ } else {
+ TVM_PY_LOG(INFO, this->logger) << "\n" << p.AsStr() << os.str();
+ }
}
TaskScheduler TaskScheduler::PyTaskScheduler(
@@ -369,8 +374,8 @@
TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTerminateTask")
.set_body_method<TaskScheduler>(&TaskSchedulerNode::TerminateTask);
TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTouchTask")
.set_body_method<TaskScheduler>(&TaskSchedulerNode::TouchTask);
-TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTuningStatistics")
- .set_body_method<TaskScheduler>(&TaskSchedulerNode::TuningStatistics);
+TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerPrintTuningStatistics")
+ .set_body_method<TaskScheduler>(&TaskSchedulerNode::PrintTuningStatistics);
} // namespace meta_schedule
} // namespace tvm
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index 41d8ffde55..b14717f4b2 100644
--- a/src/meta_schedule/utils.h
+++ b/src/meta_schedule/utils.h
@@ -82,36 +82,32 @@ class PyLogMessage {
// FATAL not included
};
- explicit PyLogMessage(const char* file, int lineno, PackedFunc logger, Level
logging_level)
- : file_(file), lineno_(lineno), logger_(logger),
logging_level_(logging_level) {
- if (this->logger_ != nullptr) {
- stream_ << "" << file_ << ":" << lineno_ << " ";
- }
- }
+ explicit PyLogMessage(const char* filename, int lineno, PackedFunc logger,
Level logging_level)
+ : filename_(filename), lineno_(lineno), logger_(logger),
logging_level_(logging_level) {}
TVM_NO_INLINE ~PyLogMessage() {
ICHECK(logging_level_ != Level::CLEAR)
<< "Cannot use CLEAR as logging level in TVM_PY_LOG, please use
TVM_PY_LOG_CLEAR_SCREEN.";
if (this->logger_ != nullptr) {
- logger_(static_cast<int>(logging_level_), stream_.str());
+ logger_(static_cast<int>(logging_level_), std::string(filename_),
lineno_, stream_.str());
} else {
if (logging_level_ == Level::INFO) {
- runtime::detail::LogMessage(file_, lineno_).stream() << stream_.str();
+ runtime::detail::LogMessage(filename_, lineno_).stream() <<
stream_.str();
} else if (logging_level_ == Level::WARNING) {
- runtime::detail::LogMessage(file_, lineno_).stream() << "Warning: " <<
stream_.str();
+ runtime::detail::LogMessage(filename_, lineno_).stream() << "Warning:
" << stream_.str();
} else if (logging_level_ == Level::ERROR) {
- runtime::detail::LogMessage(file_, lineno_).stream() << "Error: " <<
stream_.str();
+ runtime::detail::LogMessage(filename_, lineno_).stream() << "Error: "
<< stream_.str();
} else if (logging_level_ == Level::DEBUG) {
- runtime::detail::LogMessage(file_, lineno_).stream() << "Debug: " <<
stream_.str();
+ runtime::detail::LogMessage(filename_, lineno_).stream() << "Debug: "
<< stream_.str();
} else {
- runtime::detail::LogFatal(file_, lineno_).stream() << stream_.str();
+ runtime::detail::LogFatal(filename_, lineno_).stream() <<
stream_.str();
}
}
}
std::ostringstream& stream() { return stream_; }
private:
- const char* file_;
+ const char* filename_;
int lineno_;
std::ostringstream stream_;
PackedFunc logger_;
@@ -131,6 +127,18 @@ inline bool using_ipython() {
return flag;
}
+/*!
+ * \brief Print out the performance table interactively in a Jupyter notebook.
+ * \param data The serialized performance table.
+ */
+inline void print_interactive_table(const String& data) {
+ const auto* f_print_interactive_table =
+ runtime::Registry::Get("meta_schedule.print_interactive_table");
+ ICHECK(f_print_interactive_table->defined())
+ << "Cannot find print_interactive_table function in registry.";
+ (*f_print_interactive_table)(data);
+}
+
/*!
* \brief A helper function to clear logging output for ipython kernel and
console.
* \param file The file name.
@@ -139,7 +147,7 @@ inline bool using_ipython() {
*/
inline void clear_logging(const char* file, int lineno, PackedFunc
logging_func) {
if (logging_func.defined() && using_ipython()) {
- logging_func(static_cast<int>(PyLogMessage::Level::CLEAR), "");
+ logging_func(static_cast<int>(PyLogMessage::Level::CLEAR), file, lineno,
"");
} else {
// this would clear all logging output in the console
runtime::detail::LogMessage(file, lineno).stream() <<
"\033c\033[3J\033[2J\033[0m\033[H";