This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new e416226bd7 [Unity][MetaSchedule] BlockCollector focusing on current func (#14595)
e416226bd7 is described below

commit e416226bd723fce327dbf41be46cd7381205058f
Author: Ruihang Lai <[email protected]>
AuthorDate: Wed Apr 12 13:52:37 2023 -0400

    [Unity][MetaSchedule] BlockCollector focusing on current func (#14595)
    
    PR #11999 introduced the syntactic-sugar method `work_on` to TIR Schedule,
    along with a new field `func_working_on_` in ScheduleNode. In some
    cases we want to know which function a ScheduleNode is working on,
    but this was not previously exposed.
    
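    Therefore, this PR adds a method `func_working_on` to ScheduleNode that
    returns the function (more precisely, its GlobalVar) currently being
    worked on. MetaSchedule's BlockCollector uses it: when a function is
    being worked on, it collects blocks only from that function instead
    of from every PrimFunc in the IRModule.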
---
 include/tvm/tir/schedule/schedule.h                |  2 ++
 python/tvm/tir/schedule/schedule.py                |  7 ++++-
 src/meta_schedule/utils.h                          | 30 ++++++++++++++--------
 src/tir/schedule/concrete_schedule.h               |  1 +
 src/tir/schedule/schedule.cc                       |  2 ++
 .../python/unittest/test_tir_schedule_utilities.py |  1 +
 6 files changed, 32 insertions(+), 11 deletions(-)

diff --git a/include/tvm/tir/schedule/schedule.h 
b/include/tvm/tir/schedule/schedule.h
index 570560c62d..e7b7e1f453 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -115,6 +115,8 @@ class ScheduleNode : public runtime::Object {
   virtual ScheduleState state() const = 0;
   /*! \return The internally maintained trace of scheduling program execution 
*/
   virtual Optional<Trace> trace() const = 0;
+  /*! \return The GlobalVar of the func that the schedule is currently working 
on */
+  virtual Optional<GlobalVar> func_working_on() const = 0;
   /*!
    * \brief Instruct the schedule to work on a function in the IRModule.
    *
diff --git a/python/tvm/tir/schedule/schedule.py 
b/python/tvm/tir/schedule/schedule.py
index 68f0b9454c..7221fa48b0 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -19,7 +19,7 @@ from typing import Callable, Dict, List, Optional, Tuple, 
Union
 
 from tvm._ffi import register_object as _register_object
 from tvm.error import TVMError, register_error
-from tvm.ir import IRModule, PrimExpr
+from tvm.ir import GlobalVar, IRModule, PrimExpr
 from tvm.runtime import Object, String
 from tvm.tir import Block, Buffer, FloatImm, For, IntImm, PrimFunc
 
@@ -207,6 +207,11 @@ class Schedule(Object):
         """Returns the internally maintained trace of scheduling program 
execution"""
         return _ffi_api.ScheduleGetTrace(self)  # type: ignore # pylint: 
disable=no-member
 
+    @property
+    def func_working_on(self) -> Optional[GlobalVar]:
+        """Returns the GlobalVar of the func that the schedule is currently 
working on"""
+        return _ffi_api.ScheduleGetFuncWorkingOn(self)  # type: ignore # 
pylint: disable=no-member
+
     def work_on(self, func_name: str) -> None:
         """Instruct the schedule to work on a function in the IRModule.
 
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index 955381b740..753974571a 100644
--- a/src/meta_schedule/utils.h
+++ b/src/meta_schedule/utils.h
@@ -566,16 +566,26 @@ class BlockCollector : public tir::StmtVisitor {
   /*! \brief Entry point */
   Array<tir::BlockRV> Run() {
     std::vector<tir::BlockRV> results;
-    for (const auto& [gv, base_func] : sch_->mod()->functions) {
-      // `gv->name_hint` is the name of the function
-      // `base_func` can be PrimFunc or relay::Function
-      if (const auto* func = base_func.as<tir::PrimFuncNode>()) {
-        func_name_ = gv->name_hint;
-        block_names_.clear();
-        blocks_to_collect_.clear();
-        VisitStmt(func->body);
-        for (const String& name : blocks_to_collect_) {
-          results.push_back(sch_->GetBlock(name, func_name_));
+    auto f_collect = [this, &results](tir::PrimFunc func, String func_name) {
+      func_name_ = func_name;
+      block_names_.clear();
+      blocks_to_collect_.clear();
+      VisitStmt(func->body);
+      for (const String& name : blocks_to_collect_) {
+        results.push_back(sch_->GetBlock(name, func_name_));
+      }
+    };
+
+    if (sch_->func_working_on().defined()) {
+      GlobalVar gv = sch_->func_working_on().value();
+      tir::PrimFunc func = Downcast<tir::PrimFunc>(sch_->mod()->functions[gv]);
+      f_collect(func, gv->name_hint);
+    } else {
+      for (const auto& [gv, base_func] : sch_->mod()->functions) {
+        // `gv->name_hint` is the name of the function
+        // `base_func` can be PrimFunc or relay::Function
+        if (const auto* func = base_func.as<tir::PrimFuncNode>()) {
+          f_collect(GetRef<tir::PrimFunc>(func), gv->name_hint);
         }
       }
     }
diff --git a/src/tir/schedule/concrete_schedule.h 
b/src/tir/schedule/concrete_schedule.h
index 227288b232..d68683c45f 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -64,6 +64,7 @@ class ConcreteScheduleNode : public ScheduleNode {
  public:
   ScheduleState state() const final { return state_; }
   Optional<Trace> trace() const override { return NullOpt; }
+  Optional<GlobalVar> func_working_on() const final { return func_working_on_; 
}
   void WorkOn(const String& func_name) final;
   Schedule Copy() override;
   void Seed(support::LinearCongruentialEngine::TRandState seed) final;
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index a0e39b74d3..ce28c39a81 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -50,6 +50,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetState")  //
     .set_body_method<Schedule>(&ScheduleNode::state);
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetTrace")  //
     .set_body_method<Schedule>(&ScheduleNode::trace);
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetFuncWorkingOn")  //
+    .set_body_method<Schedule>(&ScheduleNode::func_working_on);
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCopy")  //
     .set_body_method<Schedule>(&ScheduleNode::Copy);
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSeed")  //
diff --git a/tests/python/unittest/test_tir_schedule_utilities.py 
b/tests/python/unittest/test_tir_schedule_utilities.py
index 53ee6a58cd..0ce2f0ea91 100644
--- a/tests/python/unittest/test_tir_schedule_utilities.py
+++ b/tests/python/unittest/test_tir_schedule_utilities.py
@@ -193,6 +193,7 @@ def test_tir_schedule_work_on():
         sch.get_block(name="init")
     sch.work_on(func_name="vector_add")
     sch.get_block(name="init")
+    assert sch.func_working_on == sch.mod.get_global_var("vector_add")
 
 
 def test_tir_schedule_get_loops(use_block_name):

Reply via email to