This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 71f32ca4e8 [MetaSchedule][UX] Support Interactive Performance Table Printing in Notebook (#13006)
71f32ca4e8 is described below

commit 71f32ca4e8e6f33da55cd6c39c5019caadcdc78e
Author: Xiyou Zhou <[email protected]>
AuthorDate: Fri Oct 14 20:24:50 2022 -0700

    [MetaSchedule][UX] Support Interactive Performance Table Printing in Notebook (#13006)
    
    * Support interactive table printing.
    
    * Rebase.
    
    * Fix jupyter outputs.
    
    * Fix CI.
    
    * Fix CI.
    
    * Rename `file`/`file_` to `filename`/`filename_` in PyLogMessage.
    
    * Address issues.
---
 include/tvm/meta_schedule/task_scheduler.h         |  4 +--
 python/tvm/meta_schedule/logging.py                | 24 +++++++++------
 .../meta_schedule/task_scheduler/task_scheduler.py | 12 ++------
 python/tvm/meta_schedule/utils.py                  | 34 +++++++++++++++++++-
 src/meta_schedule/task_scheduler/gradient_based.cc |  3 +-
 src/meta_schedule/task_scheduler/task_scheduler.cc | 29 +++++++++--------
 src/meta_schedule/utils.h                          | 36 +++++++++++++---------
 7 files changed, 93 insertions(+), 49 deletions(-)

diff --git a/include/tvm/meta_schedule/task_scheduler.h b/include/tvm/meta_schedule/task_scheduler.h
index 17d82558fb..f4fc491286 100644
--- a/include/tvm/meta_schedule/task_scheduler.h
+++ b/include/tvm/meta_schedule/task_scheduler.h
@@ -196,8 +196,8 @@ class TaskSchedulerNode : public runtime::Object {
    * \param task_id The task id to be checked.
    */
   void TouchTask(int task_id);
-  /*! \brief Returns a human-readable string of the tuning statistics. */
-  std::string TuningStatistics() const;
+  /*! \brief Print out a human-readable format of the tuning statistics. */
+  void PrintTuningStatistics();
 
   static constexpr const char* _type_key = "meta_schedule.TaskScheduler";
   TVM_DECLARE_BASE_OBJECT_INFO(TaskSchedulerNode, Object);
diff --git a/python/tvm/meta_schedule/logging.py b/python/tvm/meta_schedule/logging.py
index 9d673266a3..53353e3aa9 100644
--- a/python/tvm/meta_schedule/logging.py
+++ b/python/tvm/meta_schedule/logging.py
@@ -39,7 +39,7 @@ def get_logger(name: str) -> Logger:
     return logging.getLogger(name)
 
 
-def get_logging_func(logger: Logger) -> Optional[Callable[[int, str], None]]:
+def get_logging_func(logger: Logger) -> Optional[Callable[[int, str, int, 
str], None]]:
     """Get the logging function.
 
     Parameters
@@ -62,15 +62,15 @@ def get_logging_func(logger: Logger) -> Optional[Callable[[int, str], None]]:
         # logging.FATAL not included
     }
 
-    def logging_func(level: int, msg: str):
-        if level < 0:
+    def logging_func(level: int, filename: str, lineno: int, msg: str):
+        if level < 0:  # clear the output in notebook / console
             from IPython.display import (  # type: ignore # pylint: disable=import-outside-toplevel
                 clear_output,
             )
 
             clear_output(wait=True)
         else:
-            level2log[level](msg)
+            level2log[level](f"[{os.path.basename(filename)}:{lineno}] " + msg)
 
     return logging_func
 
@@ -94,12 +94,15 @@ def create_loggers(
     global_logger_name = "tvm.meta_schedule"
     global_logger = logging.getLogger(global_logger_name)
     if global_logger.level is logging.NOTSET:
-        global_logger.setLevel(logging.INFO)
+        global_logger.setLevel(logging.DEBUG)
+    console_logging_level = logging._levelToName[  # pylint: 
disable=protected-access
+        global_logger.level
+    ]
 
     config["loggers"].setdefault(
         global_logger_name,
         {
-            "level": logging._levelToName[global_logger.level],  # pylint: 
disable=protected-access
+            "level": logging.DEBUG,
             "handlers": [handler.get_name() for handler in 
global_logger.handlers]
             + [global_logger_name + ".console", global_logger_name + ".file"],
             "propagate": False,
@@ -108,7 +111,7 @@ def create_loggers(
     config["loggers"].setdefault(
         "{logger_name}",
         {
-            "level": "INFO",
+            "level": "DEBUG",
             "handlers": [
                 "{logger_name}.file",
             ],
@@ -121,6 +124,7 @@ def create_loggers(
             "class": "logging.StreamHandler",
             "stream": "ext://sys.stdout",
             "formatter": "tvm.meta_schedule.standard_formatter",
+            "level": console_logging_level,
         },
     )
     config["handlers"].setdefault(
@@ -129,7 +133,7 @@ def create_loggers(
             "class": "logging.FileHandler",
             "filename": "{log_dir}/" + __name__ + ".task_scheduler.log",
             "mode": "a",
-            "level": "INFO",
+            "level": "DEBUG",
             "formatter": "tvm.meta_schedule.standard_formatter",
         },
     )
@@ -139,14 +143,14 @@ def create_loggers(
             "class": "logging.FileHandler",
             "filename": "{log_dir}/{logger_name}.log",
             "mode": "a",
-            "level": "INFO",
+            "level": "DEBUG",
             "formatter": "tvm.meta_schedule.standard_formatter",
         },
     )
     config["formatters"].setdefault(
         "tvm.meta_schedule.standard_formatter",
         {
-            "format": "%(asctime)s.%(msecs)03d %(levelname)s %(message)s",
+            "format": "%(asctime)s [%(levelname)s] %(message)s",
             "datefmt": "%Y-%m-%d %H:%M:%S",
         },
     )
diff --git a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py
index f06f4d911f..d56d944474 100644
--- a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py
+++ b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py
@@ -163,15 +163,9 @@ class TaskScheduler(Object):
         """
         _ffi_api.TaskSchedulerTouchTask(self, task_id)  # type: ignore # 
pylint: disable=no-member
 
-    def tuning_statistics(self) -> str:
-        """Returns a human-readable string of the tuning statistics.
-
-        Returns
-        -------
-        tuning_statistics : str
-            The tuning statistics.
-        """
-        return _ffi_api.TaskSchedulerTuningStatistics(self)  # type: ignore # 
pylint: disable=no-member
+    def print_tuning_statistics(self) -> None:
+        """Print the tuning statistics table (logged, or shown interactively in a notebook)."""
+        _ffi_api.TaskSchedulerPrintTuningStatistics(self)  # type: ignore # pylint: disable=no-member
 
     @staticmethod
     def create(  # pylint: disable=keyword-arg-before-vararg
diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py
index eb3c643760..401fdab08a 100644
--- a/python/tvm/meta_schedule/utils.py
+++ b/python/tvm/meta_schedule/utils.py
@@ -188,13 +188,45 @@ def cpu_count(logical: bool = True) -> int:
 
 
 @register_func("meta_schedule.using_ipython")
-def _using_ipython():
+def _using_ipython() -> bool:
+    """Return whether the current process is running in an IPython shell.
+
+    Returns
+    -------
+    result : bool
+        Whether the current process is running in an IPython shell.
+    """
     try:
         return get_ipython().__class__.__name__ == "ZMQInteractiveShell"  # 
type: ignore
     except NameError:
         return False
 
 
+@register_func("meta_schedule.print_interactive_table")
+def print_interactive_table(data: str) -> None:
+    """Display the serialized performance table as an interactive DataFrame in a notebook.
+
+    Parameters
+    ----------
+    data : str
+        The serialized performance table from MetaSchedule table printer: rows separated by
+        newlines, cells separated by '|', and separator rows made only of '-'.
+    """
+    import pandas as pd  # type: ignore # pylint: disable=import-outside-toplevel
+    from IPython.display import display  # type: ignore # pylint: disable=import-outside-toplevel
+
+    # Drop separator rows (made only of '-'); the first remaining row is the header.
+    parsed = [x.split("|")[1:] for x in data.strip().split("\n") if set(x) != {"-"}]
+    # Scope the display options to this call rather than mutating the user's pandas session.
+    with pd.option_context("display.max_rows", None, "display.max_colwidth", None):
+        display(
+            pd.DataFrame(
+                parsed[1:],
+                columns=parsed[0],
+            )
+        )
+
+
 def get_global_func_with_default_on_worker(
     name: Union[None, str, Callable],
     default: Callable,
diff --git a/src/meta_schedule/task_scheduler/gradient_based.cc b/src/meta_schedule/task_scheduler/gradient_based.cc
index bae52573a0..e0470337b5 100644
--- a/src/meta_schedule/task_scheduler/gradient_based.cc
+++ b/src/meta_schedule/task_scheduler/gradient_based.cc
@@ -60,7 +60,8 @@ class GradientBasedNode final : public TaskSchedulerNode {
     int n_tasks = this->tasks_.size();
     // Step 1. Check if it's in round robin mode.
     if (round_robin_rounds_ == 0) {
-      TVM_PY_LOG(INFO, this->logger) << "\n" << this->TuningStatistics();
+      TVM_PY_LOG_CLEAR_SCREEN(this->logger);
+      this->PrintTuningStatistics();
     }
     if (round_robin_rounds_ < n_tasks) {
       return round_robin_rounds_++;
diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc
index 21efde26d9..69a70f63c5 100644
--- a/src/meta_schedule/task_scheduler/task_scheduler.cc
+++ b/src/meta_schedule/task_scheduler/task_scheduler.cc
@@ -232,9 +232,8 @@ Array<RunnerResult> TaskSchedulerNode::JoinRunningTask(int task_id) {
   }
   TaskCleanUp(task, task_id, results);
   TVM_PY_LOG_CLEAR_SCREEN(this->logger);
-  TVM_PY_LOG(INFO, this->logger) << "[Updated] Task #" << task_id << ": " << 
task->ctx->task_name
-                                 << "\n"
-                                 << this->TuningStatistics();
+  TVM_PY_LOG(INFO, this->logger) << "[Updated] Task #" << task_id << ": " << 
task->ctx->task_name;
+  this->PrintTuningStatistics();
   return results;
 }
 
@@ -257,12 +256,11 @@ void TaskSchedulerNode::TerminateTask(int task_id) {
   --this->remaining_tasks_;
   TVM_PY_LOG_CLEAR_SCREEN(this->logger);
   TVM_PY_LOG(INFO, this->logger) << "Task #" << task_id
-                                 << " has finished. Remaining task(s): " << 
this->remaining_tasks_
-                                 << "\n"
-                                 << this->TuningStatistics();
+                                 << " has finished. Remaining task(s): " << 
this->remaining_tasks_;
+  this->PrintTuningStatistics();
 }
 
-std::string TaskSchedulerNode::TuningStatistics() const {
+void TaskSchedulerNode::PrintTuningStatistics() {
   std::ostringstream os;
   int n_tasks = this->tasks_.size();
   int total_trials = 0;
@@ -307,11 +305,18 @@ std::string TaskSchedulerNode::TuningStatistics() const {
     }
   }
   p.Separator();
-  os << p.AsStr()                                  //
-     << "\nTotal trials: " << total_trials         //
+
+  os << "\nTotal trials: " << total_trials         //
      << "\nTotal latency (us): " << total_latency  //
      << "\n";
-  return os.str();
+
+  if (using_ipython()) {
+    print_interactive_table(p.AsStr());
+    std::cout << os.str() << std::endl << std::flush;
+    TVM_PY_LOG(DEBUG, this->logger) << "\n" << p.AsStr() << os.str();
+  } else {
+    TVM_PY_LOG(INFO, this->logger) << "\n" << p.AsStr() << os.str();
+  }
 }
 
 TaskScheduler TaskScheduler::PyTaskScheduler(
@@ -369,8 +374,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTerminateTask")
     .set_body_method<TaskScheduler>(&TaskSchedulerNode::TerminateTask);
 TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTouchTask")
     .set_body_method<TaskScheduler>(&TaskSchedulerNode::TouchTask);
-TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTuningStatistics")
-    .set_body_method<TaskScheduler>(&TaskSchedulerNode::TuningStatistics);
+TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerPrintTuningStatistics")
+    .set_body_method<TaskScheduler>(&TaskSchedulerNode::PrintTuningStatistics);
 
 }  // namespace meta_schedule
 }  // namespace tvm
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index 41d8ffde55..b14717f4b2 100644
--- a/src/meta_schedule/utils.h
+++ b/src/meta_schedule/utils.h
@@ -82,36 +82,32 @@ class PyLogMessage {
     // FATAL not included
   };
 
-  explicit PyLogMessage(const char* file, int lineno, PackedFunc logger, Level 
logging_level)
-      : file_(file), lineno_(lineno), logger_(logger), 
logging_level_(logging_level) {
-    if (this->logger_ != nullptr) {
-      stream_ << "" << file_ << ":" << lineno_ << " ";
-    }
-  }
+  explicit PyLogMessage(const char* filename, int lineno, PackedFunc logger, 
Level logging_level)
+      : filename_(filename), lineno_(lineno), logger_(logger), 
logging_level_(logging_level) {}
 
   TVM_NO_INLINE ~PyLogMessage() {
     ICHECK(logging_level_ != Level::CLEAR)
         << "Cannot use CLEAR as logging level in TVM_PY_LOG, please use 
TVM_PY_LOG_CLEAR_SCREEN.";
     if (this->logger_ != nullptr) {
-      logger_(static_cast<int>(logging_level_), stream_.str());
+      logger_(static_cast<int>(logging_level_), std::string(filename_), 
lineno_, stream_.str());
     } else {
       if (logging_level_ == Level::INFO) {
-        runtime::detail::LogMessage(file_, lineno_).stream() << stream_.str();
+        runtime::detail::LogMessage(filename_, lineno_).stream() << 
stream_.str();
       } else if (logging_level_ == Level::WARNING) {
-        runtime::detail::LogMessage(file_, lineno_).stream() << "Warning: " << 
stream_.str();
+        runtime::detail::LogMessage(filename_, lineno_).stream() << "Warning: 
" << stream_.str();
       } else if (logging_level_ == Level::ERROR) {
-        runtime::detail::LogMessage(file_, lineno_).stream() << "Error: " << 
stream_.str();
+        runtime::detail::LogMessage(filename_, lineno_).stream() << "Error: " 
<< stream_.str();
       } else if (logging_level_ == Level::DEBUG) {
-        runtime::detail::LogMessage(file_, lineno_).stream() << "Debug: " << 
stream_.str();
+        runtime::detail::LogMessage(filename_, lineno_).stream() << "Debug: " 
<< stream_.str();
       } else {
-        runtime::detail::LogFatal(file_, lineno_).stream() << stream_.str();
+        runtime::detail::LogFatal(filename_, lineno_).stream() << 
stream_.str();
       }
     }
   }
   std::ostringstream& stream() { return stream_; }
 
  private:
-  const char* file_;
+  const char* filename_;
   int lineno_;
   std::ostringstream stream_;
   PackedFunc logger_;
@@ -131,6 +127,18 @@ inline bool using_ipython() {
   return flag;
 }
 
+/*!
+ * \brief Print out the performance table interactively in jupyter notebook.
+ * \param data The serialized performance table.
+ */
+inline void print_interactive_table(const String& data) {
+  const auto* f_print_interactive_table =
+      runtime::Registry::Get("meta_schedule.print_interactive_table");
+  ICHECK(f_print_interactive_table != nullptr)  // Registry::Get yields nullptr if unregistered
+      << "Cannot find print_interactive_table function in registry.";
+  (*f_print_interactive_table)(data);
+}
+
 /*!
  * \brief A helper function to clear logging output for ipython kernel and 
console.
  * \param file The file name.
@@ -139,7 +147,7 @@ inline bool using_ipython() {
  */
 inline void clear_logging(const char* file, int lineno, PackedFunc 
logging_func) {
   if (logging_func.defined() && using_ipython()) {
-    logging_func(static_cast<int>(PyLogMessage::Level::CLEAR), "");
+    logging_func(static_cast<int>(PyLogMessage::Level::CLEAR), file, lineno, 
"");
   } else {
     // this would clear all logging output in the console
     runtime::detail::LogMessage(file, lineno).stream() << 
"\033c\033[3J\033[2J\033[0m\033[H";

Reply via email to