This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new f8186d8c7d [TIR] Add sugar method `Schedule.work_on` (#11999)
f8186d8c7d is described below

commit f8186d8c7d3e4679a6dfd83d17521f20bfb3ca42
Author: Junru Shao <[email protected]>
AuthorDate: Sun Jul 3 13:16:18 2022 -0700

    [TIR] Add sugar method `Schedule.work_on` (#11999)
    
    This PR introduces `Schedule.work_on`, which tells
    `Schedule.get_block` which PrimFunc to retrieve blocks from, so that
    users no longer have to specify `func_name` on every call when the
    PrimFunc's name is not `main`.
---
 include/tvm/tir/schedule/schedule.h                | 24 ++++++++++++-
 python/tvm/tir/schedule/schedule.py                | 25 +++++++++++--
 src/meta_schedule/arg_info.cc                      | 41 +++++++++++++++++++++
 src/meta_schedule/mutator/mutate_parallel.cc       |  3 +-
 src/meta_schedule/utils.h                          | 42 ----------------------
 src/tir/schedule/analysis.h                        |  9 +++++
 src/tir/schedule/analysis/analysis.cc              | 41 +++++++++++++++++++++
 src/tir/schedule/concrete_schedule.cc              | 25 +++++++++++--
 src/tir/schedule/concrete_schedule.h               |  8 +++--
 src/tir/schedule/primitive.h                       |  4 +--
 src/tir/schedule/primitive/get_block_loop.cc       |  4 +--
 src/tir/schedule/schedule.cc                       |  2 ++
 src/tir/schedule/traced_schedule.cc                | 21 +++++++++--
 src/tir/schedule/traced_schedule.h                 |  2 +-
 .../python/unittest/test_tir_schedule_utilities.py | 32 ++++++++++++++++-
 15 files changed, 225 insertions(+), 58 deletions(-)

diff --git a/include/tvm/tir/schedule/schedule.h 
b/include/tvm/tir/schedule/schedule.h
index d95a9d4e7e..8e160c6132 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -115,6 +115,21 @@ class ScheduleNode : public runtime::Object {
   virtual ScheduleState state() const = 0;
   /*! \return The internally maintained trace of scheduling program execution 
*/
   virtual Optional<Trace> trace() const = 0;
+  /*!
+   * \brief Instruct the schedule to work on a function in the IRModule.
+   *
+   * By default, the schedule works on the function with the name "main", or 
the only function in
+   * the IRModule if there is only one. If there is multiple functions in the 
IRModule, and none of
+   * their names are "main", users will have to call this method to explicitly 
specify which
+   * function to work on.
+   *
+   * This sugar function will guide the `GetBlock` method if its `func_name` 
is not specified.
+   *
+   * \param func_name The name of the function to be working on
+   *
+   * \sa GetBlock
+   */
+  virtual void WorkOn(const String& func_name) = 0;
   /*!
    * \brief Returns a copy of the schedule, including both its state and its 
symbol table,
    * guaranteeing that
@@ -231,12 +246,19 @@ class ScheduleNode : public runtime::Object {
   /******** Schedule: Get blocks & loops ********/
   /*!
    * \brief Retrieve a block in a specific function with its name
+   *
+   * By default, if `func_name` is not specified, the schedule will search for 
the block in the
+   * function that is currently being "worked on". To switch the function to 
be worked on, use
+   * `WorkOn` before calling this method.
+   *
    * \param name The name of the block to be retrieved
    * \param func_name The name of the function
    * \return The block retrieved
    * \note Indexing error is raised if 0 or multiple blocks exist with the 
specific name
+   *
+   * \sa WorkOn
    */
-  virtual BlockRV GetBlock(const String& name, const String& func_name = 
"main") = 0;
+  virtual BlockRV GetBlock(const String& name, const Optional<String>& 
func_name = NullOpt) = 0;
   /*!
    * \brief Get the parent loops of the block in its scope, from outer to inner
    * \param block_rv The query block
diff --git a/python/tvm/tir/schedule/schedule.py 
b/python/tvm/tir/schedule/schedule.py
index 7a1e244604..28bdf63872 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -186,6 +186,23 @@ class Schedule(Object):
         """Returns the internally maintained trace of scheduling program 
execution"""
         return _ffi_api.ScheduleGetTrace(self)  # type: ignore # pylint: 
disable=no-member
 
+    def work_on(self, func_name: str) -> None:
+        """Instruct the schedule to work on a function in the IRModule.
+
+        By default, the schedule works on the function with the name "main", 
or the only function in
+        the IRModule if there is only one. If there is multiple functions in 
the IRModule, and none
+        of their names are "main", users will have to call this method to 
explicitly specify which
+        function to work on.
+
+        This sugar function will guide the `GetBlock` method if its 
`func_name` is not specified.
+
+        Parameters
+        ----------
+        func_name : str
+            The name of the function to work on.
+        """
+        _ffi_api.ScheduleWorkOn(self, func_name)  # type: ignore # pylint: 
disable=no-member
+
     def copy(self) -> "Schedule":
         """Returns a copy of the schedule, including both the state and the 
symbol table,
         * guaranteeing that
@@ -403,15 +420,19 @@ class Schedule(Object):
     def get_block(
         self,
         name: str,
-        func_name: str = "main",
+        func_name: Optional[str] = None,
     ) -> BlockRV:
         """Retrieve a block in a specific function with its name
 
+        By default, if `func_name` is not specified, the schedule will search 
for the block in the
+        function that is currently being "worked on". To switch the function 
to be worked on, use
+        `work_on` before calling this method.
+
         Parameters
         ----------
         name : str
             The name of the block
-        func_name : str = "main"
+        func_name : Optional[str] = None
             The name of the function
 
         Returns
diff --git a/src/meta_schedule/arg_info.cc b/src/meta_schedule/arg_info.cc
index 672df86deb..21de9d719d 100644
--- a/src/meta_schedule/arg_info.cc
+++ b/src/meta_schedule/arg_info.cc
@@ -21,6 +21,47 @@
 namespace tvm {
 namespace meta_schedule {
 
+/*!
+ * \brief Find the entry function of the given IRModule, i.e, functions marked 
by
+ * `tir::attr::kIsEntryFunc`, whose name is `main` or being the only PrimeFunc.
+ * \param mod The IRModule to find the entry function.
+ * \return The entry function.
+ */
+inline tir::PrimFunc FindEntryFunc(const IRModule& mod) {
+  // Priority 1: PrimFunc marked as `tir::attr::kIsEntryFunc`
+  int num_prim_func = 0;
+  const tir::PrimFuncNode* main_func = nullptr;
+  const tir::PrimFuncNode* last_func = nullptr;
+  for (const auto& kv : mod->functions) {
+    GlobalVar gv = kv.first;
+    BaseFunc base_func = kv.second;
+    if (const auto* func = base_func.as<tir::PrimFuncNode>()) {
+      last_func = func;
+      if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
+        return GetRef<tir::PrimFunc>(func);
+      }
+      if (gv->name_hint == "main") {
+        main_func = func;
+      }
+      ++num_prim_func;
+    }
+  }
+  // Priority 2: PrimFunc whose name is `main`
+  if (main_func != nullptr) {
+    return GetRef<tir::PrimFunc>(main_func);
+  }
+  // Priority 3: The only PrimFunc in the IRModule
+  if (num_prim_func == 0) {
+    LOG(FATAL) << "ValueError: Cannot find any PrimFunc in the given IRModule: 
"
+               << tir::AsTVMScript(mod);
+  }
+  if (num_prim_func > 1) {
+    LOG(FATAL) << "ValueError: Multiple PrimFuncs exist in the IRModule, but 
none of them are "
+                  "annotated with `kIsEntryFunc`, i.e. `tir.is_entry_func`"
+               << tir::AsTVMScript(mod);
+  }
+  return GetRef<tir::PrimFunc>(last_func);
+}
 /******** ArgInfo ********/
 
 ArgInfo ArgInfo::FromJSON(const ObjectRef& json_obj) {
diff --git a/src/meta_schedule/mutator/mutate_parallel.cc 
b/src/meta_schedule/mutator/mutate_parallel.cc
index 7c973879f2..5b7fe7f514 100644
--- a/src/meta_schedule/mutator/mutate_parallel.cc
+++ b/src/meta_schedule/mutator/mutate_parallel.cc
@@ -79,7 +79,8 @@ const BlockRVNode* GetInstGetBlockOutput(const Instruction& 
inst) {
 std::vector<std::vector<int64_t>> AnalyzeParallel(const ScheduleState& self,
                                                   const String& block_name, 
const String& func_name,
                                                   int64_t limit) {
-  Array<StmtSRef> block_srefs = tir::GetBlocks(self, block_name, func_name);
+  Array<StmtSRef> block_srefs =
+      tir::GetBlocks(self, block_name, self->mod->GetGlobalVar(func_name));
   ICHECK_EQ(block_srefs.size(), 1);
   const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_srefs[0]);
   ScopeBlockLoopInfo info = GetScopeBlockLoopInfo(GetRef<Block>(block));
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index ca696da71e..b5cb73c26e 100644
--- a/src/meta_schedule/utils.h
+++ b/src/meta_schedule/utils.h
@@ -174,48 +174,6 @@ inline String SHash2Hex(const ObjectRef& obj) {
   return os.str();
 }
 
-/*!
- * \brief Find the entry function of the given IRModule, i.e, functions marked 
by
- * `tir::attr::kIsEntryFunc`, whose name is `main` or being the only PrimeFunc.
- * \param mod The IRModule to find the entry function.
- * \return The entry function.
- */
-inline tir::PrimFunc FindEntryFunc(const IRModule& mod) {
-  // Priority 1: PrimFunc marked as `tir::attr::kIsEntryFunc`
-  int num_prim_func = 0;
-  const tir::PrimFuncNode* main_func = nullptr;
-  const tir::PrimFuncNode* last_func = nullptr;
-  for (const auto& kv : mod->functions) {
-    GlobalVar gv = kv.first;
-    BaseFunc base_func = kv.second;
-    if (const auto* func = base_func.as<tir::PrimFuncNode>()) {
-      last_func = func;
-      if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
-        return GetRef<tir::PrimFunc>(func);
-      }
-      if (gv->name_hint == "main") {
-        main_func = func;
-      }
-      ++num_prim_func;
-    }
-  }
-  // Priority 2: PrimFunc whose name is `main`
-  if (main_func != nullptr) {
-    return GetRef<tir::PrimFunc>(main_func);
-  }
-  // Priority 3: The only PrimFunc in the IRModule
-  if (num_prim_func == 0) {
-    LOG(FATAL) << "ValueError: Cannot find any PrimFunc in the given IRModule: 
"
-               << tir::AsTVMScript(mod);
-  }
-  if (num_prim_func > 1) {
-    LOG(FATAL) << "ValueError: Multiple PrimFuncs exist in the IRModule, but 
none of them are "
-                  "annotated with `kIsEntryFunc`, i.e. `tir.is_entry_func`"
-               << tir::AsTVMScript(mod);
-  }
-  return GetRef<tir::PrimFunc>(last_func);
-}
-
 /*!
  * \brief Fork a random state into another, i.e. PRNG splitting.
  * The given random state is also mutated.
diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
index b30cef829f..317b3625f0 100644
--- a/src/tir/schedule/analysis.h
+++ b/src/tir/schedule/analysis.h
@@ -71,6 +71,15 @@ const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, 
const StmtNode* root_bl
  */
 StmtSRef GetSRefTreeRoot(const StmtSRef& sref);
 
+/*!
+ * \brief Find the entry function of the given IRModule, i.e, functions marked 
by
+ * `tir::attr::kIsEntryFunc`, whose name is `main` or being the only PrimeFunc.
+ * \param mod The IRModule to find the entry function.
+ * \param result_g_var The result GlobalVar of the entry function.
+ * \return The entry function.
+ */
+const PrimFuncNode* FindEntryFunc(const IRModule& mod, GlobalVar* 
result_g_var);
+
 /******** Scope ********/
 /*!
  * \brief Checks if scope the specified sref is in is a stage-pipeline and 
return it
diff --git a/src/tir/schedule/analysis/analysis.cc 
b/src/tir/schedule/analysis/analysis.cc
index 3ee1ed28b8..ac73ac3ce2 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -49,6 +49,47 @@ const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, 
const StmtNode* root_bl
   throw;
 }
 
+const PrimFuncNode* FindEntryFunc(const IRModule& mod, GlobalVar* 
result_g_var) {
+  GlobalVar result = NullValue<GlobalVar>();
+  // Priority 1: PrimFunc marked as `tir::attr::kIsEntryFunc`
+  int num_prim_func = 0;
+  const tir::PrimFuncNode* main_func = nullptr;
+  const tir::PrimFuncNode* last_func = nullptr;
+  for (const auto& kv : mod->functions) {
+    GlobalVar gv = kv.first;
+    BaseFunc base_func = kv.second;
+    if (const auto* func = base_func.as<tir::PrimFuncNode>()) {
+      last_func = func;
+      if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
+        if (result_g_var != nullptr) {
+          *result_g_var = gv;
+        }
+        return func;
+      }
+      if (gv->name_hint == "main") {
+        main_func = func;
+        result = gv;
+      }
+      ++num_prim_func;
+    }
+  }
+  // Priority 2: PrimFunc whose name is `main`
+  if (main_func != nullptr) {
+    if (result_g_var != nullptr) {
+      *result_g_var = result;
+    }
+    return main_func;
+  }
+  // Priority 3: The only PrimFunc in the IRModule
+  if (num_prim_func == 1) {
+    if (result_g_var != nullptr) {
+      *result_g_var = result;
+    }
+    return last_func;
+  }
+  return nullptr;
+}
+
 /******** Scope ********/
 
 StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref,
diff --git a/src/tir/schedule/concrete_schedule.cc 
b/src/tir/schedule/concrete_schedule.cc
index b2f48753b5..c19735025d 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -31,6 +31,12 @@ Schedule Schedule::Concrete(IRModule mod, 
support::LinearCongruentialEngine::TRa
   n->symbol_table_ = {};
   n->analyzer_ = std::make_unique<arith::Analyzer>();
   n->Seed(seed);
+  GlobalVar gv = NullValue<GlobalVar>();
+  if (FindEntryFunc(mod, &gv) != nullptr) {
+    n->func_working_on_ = gv;
+  } else {
+    n->func_working_on_ = NullOpt;
+  }
   return Schedule(std::move(n));
 }
 
@@ -177,6 +183,10 @@ class ScheduleCopier {
   std::unordered_map<const StmtSRefNode*, StmtSRef> old2new_;
 };
 
+void ConcreteScheduleNode::WorkOn(const String& func_name) {
+  this->func_working_on_ = this->state_->mod->GetGlobalVar(func_name);
+}
+
 void ConcreteScheduleNode::Copy(ScheduleState* new_state, TSymbolTable* 
new_symbol_table) const {
   ScheduleCopier::Copy(this, new_state, new_symbol_table);
   new_state->get()->DebugVerify();
@@ -184,6 +194,7 @@ void ConcreteScheduleNode::Copy(ScheduleState* new_state, 
TSymbolTable* new_symb
 
 Schedule ConcreteScheduleNode::Copy() {
   ObjectPtr<ConcreteScheduleNode> n = make_object<ConcreteScheduleNode>();
+  n->func_working_on_ = this->func_working_on_;
   n->error_render_level_ = this->error_render_level_;
   ConcreteScheduleNode::Copy(&n->state_, &n->symbol_table_);
   n->analyzer_ = std::make_unique<arith::Analyzer>();  // new analyzer needed 
because it is stateful
@@ -251,7 +262,7 @@ LoopRV ConcreteScheduleNode::SampleComputeLocation(const 
BlockRV& block_rv,
 
 /******** Schedule: Get blocks & loops ********/
 
-BlockRV ConcreteScheduleNode::GetBlock(const String& name, const String& 
func_name) {
+BlockRV ConcreteScheduleNode::GetBlock(const String& name, const 
Optional<String>& func_name) {
   class NotSingleResult : public ScheduleError {
    public:
     explicit NotSingleResult(String name, IRModule mod, const Array<StmtSRef>& 
blocks)
@@ -286,7 +297,17 @@ BlockRV ConcreteScheduleNode::GetBlock(const String& name, 
const String& func_na
     IRModule mod_;
     Array<Block> blocks_;
   };
-  Array<StmtSRef> blocks = tir::GetBlocks(this->state_, name, func_name);
+  GlobalVar gv = NullValue<GlobalVar>();
+  if (func_name.defined()) {
+    gv = state_->mod->GetGlobalVar(func_name.value());
+  } else if (func_working_on_.defined()) {
+    gv = this->func_working_on_.value();
+  } else {
+    LOG(FATAL) << "ValueError: `get_block` does not know which function to be 
working on. Please "
+                  "specify the function name explicitly, or call `work_on` to 
specify the function "
+                  "before using `get_block`.";
+  }
+  Array<StmtSRef> blocks = tir::GetBlocks(this->state_, name, gv);
   if (blocks.size() != 1) {
     TVM_TIR_SCHEDULE_BEGIN();
     throw NotSingleResult(name, this->state_->mod, blocks);
diff --git a/src/tir/schedule/concrete_schedule.h 
b/src/tir/schedule/concrete_schedule.h
index dfbacb530a..feea310bd7 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -38,6 +38,8 @@ class ConcreteScheduleNode : public ScheduleNode {
  protected:
   /*! \brief The internal state of scheduling */
   ScheduleState state_;
+  /*! \brief The function to be worked on. */
+  Optional<GlobalVar> func_working_on_;
   /*! \brief The level of error rendering */
   ScheduleErrorRenderLevel error_render_level_;
   /*! \brief A symbol table that maps random variables to concrete 
StmtSRef/Integers */
@@ -50,10 +52,11 @@ class ConcreteScheduleNode : public ScheduleNode {
  public:
   void VisitAttrs(tvm::AttrVisitor* v) {
     // `state_` is not visited
+    // `func_working_on_` is not visited
     // `error_render_level_` is not visited
     // `symbol_table_` is not visited
     // `analyzer_` is not visited
-    // `rand_state_` is not visited
+    // `rgnd_state_` is not visited
   }
 
   virtual ~ConcreteScheduleNode() = default;
@@ -61,6 +64,7 @@ class ConcreteScheduleNode : public ScheduleNode {
  public:
   ScheduleState state() const final { return state_; }
   Optional<Trace> trace() const override { return NullOpt; }
+  void WorkOn(const String& func_name) final;
   Schedule Copy() override;
   void Seed(support::LinearCongruentialEngine::TRandState seed) final;
   support::LinearCongruentialEngine::TRandState ForkSeed() final;
@@ -89,7 +93,7 @@ class ConcreteScheduleNode : public ScheduleNode {
   LoopRV SampleComputeLocation(const BlockRV& block_rv,
                                Optional<Integer> decision = NullOpt) override;
   /******** Schedule: Get blocks & loops ********/
-  BlockRV GetBlock(const String& name, const String& func_name = "main") 
override;
+  BlockRV GetBlock(const String& name, const Optional<String>& func_name) 
override;
   Array<LoopRV> GetLoops(const BlockRV& block_rv) override;
   Array<BlockRV> GetChildBlocks(const BlockRV& block_rv) override;
   Array<BlockRV> GetChildBlocks(const LoopRV& loop_rv) override;
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index 212571df10..608368fbb3 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -116,10 +116,10 @@ TVM_DLL tir::StmtSRef SampleComputeLocation(
  * \brief Retrieves blocks in a specific function with its name
  * \param self The schedule state
  * \param name The name of the blocks to be retrieved
- * \param func_name The name of the function
+ * \param gvar The function to be retrieved
  * \return A list of blocks with the specific name
  */
-Array<StmtSRef> GetBlocks(const ScheduleState& self, const String& name, const 
String& func_name);
+Array<StmtSRef> GetBlocks(const ScheduleState& self, const String& name, const 
GlobalVar& gv);
 /*!
  * \brief Gets the parent loops of the block in its scope, from outer to inner
  * \param self The schedule state
diff --git a/src/tir/schedule/primitive/get_block_loop.cc 
b/src/tir/schedule/primitive/get_block_loop.cc
index a13e525157..746918ac4e 100644
--- a/src/tir/schedule/primitive/get_block_loop.cc
+++ b/src/tir/schedule/primitive/get_block_loop.cc
@@ -21,7 +21,7 @@
 namespace tvm {
 namespace tir {
 
-Array<StmtSRef> GetBlocks(const ScheduleState& self, const String& name, const 
String& func_name) {
+Array<StmtSRef> GetBlocks(const ScheduleState& self, const String& name, const 
GlobalVar& gv) {
   struct Finder : public StmtVisitor {
     explicit Finder(const ScheduleState& self, const String& name) : 
self_(self), name_(name) {}
 
@@ -39,7 +39,7 @@ Array<StmtSRef> GetBlocks(const ScheduleState& self, const 
String& name, const S
     Array<StmtSRef> results_;
   };
 
-  BaseFunc func = self->mod->Lookup(func_name);
+  BaseFunc func = self->mod->Lookup(gv);
   const auto* prim_func = TVM_TYPE_AS(prim_func, func, PrimFuncNode);
   Finder finder(self, name);
   finder(prim_func->body);
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index 372d94a150..e386061ebf 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -56,6 +56,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSeed")  //
     .set_body_method<Schedule>(&ScheduleNode::Seed);
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleForkSeed")  //
     .set_body_method<Schedule>(&ScheduleNode::ForkSeed);
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleWorkOn")  //
+    .set_body_method<Schedule>(&ScheduleNode::WorkOn);
 
 /**************** (FFI) Constructor ****************/
 
diff --git a/src/tir/schedule/traced_schedule.cc 
b/src/tir/schedule/traced_schedule.cc
index 733b5d872f..93e4c984a4 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -30,6 +30,12 @@ Schedule Schedule::Traced(IRModule mod, 
support::LinearCongruentialEngine::TRand
   n->analyzer_ = std::make_unique<arith::Analyzer>();
   n->trace_ = Trace();
   n->Seed(seed);
+  GlobalVar gv = NullValue<GlobalVar>();
+  if (FindEntryFunc(mod, &gv) != nullptr) {
+    n->func_working_on_ = gv;
+  } else {
+    n->func_working_on_ = NullOpt;
+  }
   return Schedule(std::move(n));
 }
 
@@ -37,6 +43,7 @@ Schedule TracedScheduleNode::Copy() {
   ObjectPtr<TracedScheduleNode> n = make_object<TracedScheduleNode>();
   n->error_render_level_ = this->error_render_level_;
   ConcreteScheduleNode::Copy(&n->state_, &n->symbol_table_);
+  n->func_working_on_ = this->func_working_on_;
   n->analyzer_ = std::make_unique<arith::Analyzer>();  // new analyzer needed 
because it is stateful
   n->rand_state_ = ForkSeed();
   n->trace_ = Trace(this->trace_->insts, this->trace_->decisions);
@@ -90,13 +97,23 @@ LoopRV TracedScheduleNode::SampleComputeLocation(const 
BlockRV& block_rv,
 
 /******** Schedule: Get blocks & loops ********/
 
-BlockRV TracedScheduleNode::GetBlock(const String& name, const String& 
func_name) {
+BlockRV TracedScheduleNode::GetBlock(const String& name, const 
Optional<String>& func_name) {
+  GlobalVar gv = NullValue<GlobalVar>();
+  if (func_name.defined()) {
+    gv = state_->mod->GetGlobalVar(func_name.value());
+  } else if (func_working_on_.defined()) {
+    gv = this->func_working_on_.value();
+  } else {
+    LOG(FATAL) << "ValueError: `get_block` does not know which function to be 
working on. Please "
+                  "specify the function name explicitly, or call `work_on` to 
specify the function "
+                  "before using `get_block`.";
+  }
   BlockRV result = ConcreteScheduleNode::GetBlock(name, func_name);
 
   static const InstructionKind& kind = InstructionKind::Get("GetBlock");
   trace_->Append(/*inst=*/Instruction(/*kind=*/kind,  //
                                       /*inputs=*/{},
-                                      /*attrs=*/{name, func_name},
+                                      /*attrs=*/{name, gv->name_hint},
                                       /*outputs=*/{result}));
   return result;
 }
diff --git a/src/tir/schedule/traced_schedule.h 
b/src/tir/schedule/traced_schedule.h
index 178026d9ea..f6405d77a1 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -53,7 +53,7 @@ class TracedScheduleNode : public ConcreteScheduleNode {
                                   Optional<Array<Integer>> decision = NullOpt) 
final;
   LoopRV SampleComputeLocation(const BlockRV& block_rv, Optional<Integer> 
decision = NullOpt) final;
   /******** Schedule: Get blocks & loops ********/
-  BlockRV GetBlock(const String& name, const String& func_name = "main") final;
+  BlockRV GetBlock(const String& name, const Optional<String>& func_name) 
final;
   Array<LoopRV> GetLoops(const BlockRV& block_rv) final;
   Array<BlockRV> GetChildBlocks(const BlockRV& block_rv) final;
   Array<BlockRV> GetChildBlocks(const LoopRV& loop_rv) final;
diff --git a/tests/python/unittest/test_tir_schedule_utilities.py 
b/tests/python/unittest/test_tir_schedule_utilities.py
index b7517aab7c..c479555590 100644
--- a/tests/python/unittest/test_tir_schedule_utilities.py
+++ b/tests/python/unittest/test_tir_schedule_utilities.py
@@ -20,7 +20,6 @@ import sys
 import pytest
 import tvm
 import tvm.testing
-
 from tvm import tir
 from tvm.ir import IRModule
 from tvm.script import tir as T
@@ -102,6 +101,29 @@ def matmul_relu_ann2(a: T.handle, b: T.handle, d: 
T.handle) -> None:
             D[vi, vj] = T.max(C[vi, vj], 0.0)
 
 
[email protected]_module
+class ModuleWithMultipleFuncs:
+    @T.prim_func
+    def vector_add(
+        A: T.Buffer[128, "float32"],
+        B: T.Buffer[128, "float32"],
+    ) -> None:
+        for i in range(128):
+            with T.block("init"):
+                vi = T.axis.remap("S", [i])
+                B[vi] = A[vi]
+
+    @T.prim_func
+    def vector_add_2(
+        A: T.Buffer[128, "float32"],
+        B: T.Buffer[128, "float32"],
+    ) -> None:
+        for i in range(128):
+            with T.block("init"):
+                vi = T.axis.remap("S", [i])
+                B[vi] = A[vi]
+
+
 # pylint: enable=no-member,invalid-name,unused-variable
 
 use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, 
"block_name": True})
@@ -133,6 +155,14 @@ def test_tir_schedule_get_block():
     assert block.same_as(matmul.body.block.body.body.body[1].body.block)
 
 
+def test_tir_schedule_work_on():
+    sch = tir.Schedule(ModuleWithMultipleFuncs, debug_mask="all")
+    with pytest.raises(ValueError, match="does not know which function to be 
working on"):
+        sch.get_block(name="init")
+    sch.work_on(func_name="vector_add")
+    sch.get_block(name="init")
+
+
 def test_tir_schedule_get_loops(use_block_name):
     # Tests:
     # - Schedule.get_loops

Reply via email to