This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new e416226bd7 [Unity][MetaSchedule] BlockCollector focusing on current
func (#14595)
e416226bd7 is described below
commit e416226bd723fce327dbf41be46cd7381205058f
Author: Ruihang Lai <[email protected]>
AuthorDate: Wed Apr 12 13:52:37 2023 -0400
[Unity][MetaSchedule] BlockCollector focusing on current func (#14595)
PR #11999 introduced the sugar method `work_on` to TIR Schedule, along
with a field `func_working_on_` newly added to ScheduleNode. In some
cases we may want to know which function a ScheduleNode is working on,
but this information was not previously exposed.
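Therefore, this PR adds a method `func_working_on()` to ScheduleNode
that returns the function (more precisely, its GlobalVar) currently being
worked on, as an Optional that is empty when no function is selected. It
is also exposed in Python as the `Schedule.func_working_on` property.
MetaSchedule's BlockCollector now uses it to collect blocks only from the
function being worked on, falling back to all PrimFuncs in the module
when none is selected.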
---
include/tvm/tir/schedule/schedule.h | 2 ++
python/tvm/tir/schedule/schedule.py | 7 ++++-
src/meta_schedule/utils.h | 30 ++++++++++++++--------
src/tir/schedule/concrete_schedule.h | 1 +
src/tir/schedule/schedule.cc | 2 ++
.../python/unittest/test_tir_schedule_utilities.py | 1 +
6 files changed, 32 insertions(+), 11 deletions(-)
diff --git a/include/tvm/tir/schedule/schedule.h
b/include/tvm/tir/schedule/schedule.h
index 570560c62d..e7b7e1f453 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -115,6 +115,8 @@ class ScheduleNode : public runtime::Object {
virtual ScheduleState state() const = 0;
/*! \return The internally maintained trace of scheduling program execution
*/
virtual Optional<Trace> trace() const = 0;
+ /*! \return The GlobalVar of the func that the schedule is currently working
on */
+ virtual Optional<GlobalVar> func_working_on() const = 0;
/*!
* \brief Instruct the schedule to work on a function in the IRModule.
*
diff --git a/python/tvm/tir/schedule/schedule.py
b/python/tvm/tir/schedule/schedule.py
index 68f0b9454c..7221fa48b0 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -19,7 +19,7 @@ from typing import Callable, Dict, List, Optional, Tuple,
Union
from tvm._ffi import register_object as _register_object
from tvm.error import TVMError, register_error
-from tvm.ir import IRModule, PrimExpr
+from tvm.ir import GlobalVar, IRModule, PrimExpr
from tvm.runtime import Object, String
from tvm.tir import Block, Buffer, FloatImm, For, IntImm, PrimFunc
@@ -207,6 +207,11 @@ class Schedule(Object):
"""Returns the internally maintained trace of scheduling program
execution"""
return _ffi_api.ScheduleGetTrace(self) # type: ignore # pylint:
disable=no-member
+ @property
+ def func_working_on(self) -> Optional[GlobalVar]:
+ """Returns the GlobalVar of the func that the schedule is currently
working on"""
+ return _ffi_api.ScheduleGetFuncWorkingOn(self) # type: ignore #
pylint: disable=no-member
+
def work_on(self, func_name: str) -> None:
"""Instruct the schedule to work on a function in the IRModule.
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index 955381b740..753974571a 100644
--- a/src/meta_schedule/utils.h
+++ b/src/meta_schedule/utils.h
@@ -566,16 +566,26 @@ class BlockCollector : public tir::StmtVisitor {
/*! \brief Entry point */
Array<tir::BlockRV> Run() {
std::vector<tir::BlockRV> results;
- for (const auto& [gv, base_func] : sch_->mod()->functions) {
- // `gv->name_hint` is the name of the function
- // `base_func` can be PrimFunc or relay::Function
- if (const auto* func = base_func.as<tir::PrimFuncNode>()) {
- func_name_ = gv->name_hint;
- block_names_.clear();
- blocks_to_collect_.clear();
- VisitStmt(func->body);
- for (const String& name : blocks_to_collect_) {
- results.push_back(sch_->GetBlock(name, func_name_));
+ auto f_collect = [this, &results](tir::PrimFunc func, String func_name) {
+ func_name_ = func_name;
+ block_names_.clear();
+ blocks_to_collect_.clear();
+ VisitStmt(func->body);
+ for (const String& name : blocks_to_collect_) {
+ results.push_back(sch_->GetBlock(name, func_name_));
+ }
+ };
+
+ if (sch_->func_working_on().defined()) {
+ GlobalVar gv = sch_->func_working_on().value();
+ tir::PrimFunc func = Downcast<tir::PrimFunc>(sch_->mod()->functions[gv]);
+ f_collect(func, gv->name_hint);
+ } else {
+ for (const auto& [gv, base_func] : sch_->mod()->functions) {
+ // `gv->name_hint` is the name of the function
+ // `base_func` can be PrimFunc or relay::Function
+ if (const auto* func = base_func.as<tir::PrimFuncNode>()) {
+ f_collect(GetRef<tir::PrimFunc>(func), gv->name_hint);
}
}
}
diff --git a/src/tir/schedule/concrete_schedule.h
b/src/tir/schedule/concrete_schedule.h
index 227288b232..d68683c45f 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -64,6 +64,7 @@ class ConcreteScheduleNode : public ScheduleNode {
public:
ScheduleState state() const final { return state_; }
Optional<Trace> trace() const override { return NullOpt; }
+ Optional<GlobalVar> func_working_on() const final { return func_working_on_;
}
void WorkOn(const String& func_name) final;
Schedule Copy() override;
void Seed(support::LinearCongruentialEngine::TRandState seed) final;
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index a0e39b74d3..ce28c39a81 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -50,6 +50,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetState") //
.set_body_method<Schedule>(&ScheduleNode::state);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetTrace") //
.set_body_method<Schedule>(&ScheduleNode::trace);
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetFuncWorkingOn") //
+ .set_body_method<Schedule>(&ScheduleNode::func_working_on);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCopy") //
.set_body_method<Schedule>(&ScheduleNode::Copy);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSeed") //
diff --git a/tests/python/unittest/test_tir_schedule_utilities.py
b/tests/python/unittest/test_tir_schedule_utilities.py
index 53ee6a58cd..0ce2f0ea91 100644
--- a/tests/python/unittest/test_tir_schedule_utilities.py
+++ b/tests/python/unittest/test_tir_schedule_utilities.py
@@ -193,6 +193,7 @@ def test_tir_schedule_work_on():
sch.get_block(name="init")
sch.work_on(func_name="vector_add")
sch.get_block(name="init")
+ assert sch.func_working_on == sch.mod.get_global_var("vector_add")
def test_tir_schedule_get_loops(use_block_name):