This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new f8186d8c7d [TIR] Add sugar method `Schedule.work_on` (#11999)
f8186d8c7d is described below
commit f8186d8c7d3e4679a6dfd83d17521f20bfb3ca42
Author: Junru Shao <[email protected]>
AuthorDate: Sun Jul 3 13:16:18 2022 -0700
[TIR] Add sugar method `Schedule.work_on` (#11999)
This PR introduces `Schedule.work_on`, which instructs
`Schedule.get_block` to find the correct PrimFunc to retrieve from
without having to specify `func_name` every time if the PrimFunc's
name is not `main`.
---
include/tvm/tir/schedule/schedule.h | 24 ++++++++++++-
python/tvm/tir/schedule/schedule.py | 25 +++++++++++--
src/meta_schedule/arg_info.cc | 41 +++++++++++++++++++++
src/meta_schedule/mutator/mutate_parallel.cc | 3 +-
src/meta_schedule/utils.h | 42 ----------------------
src/tir/schedule/analysis.h | 9 +++++
src/tir/schedule/analysis/analysis.cc | 41 +++++++++++++++++++++
src/tir/schedule/concrete_schedule.cc | 25 +++++++++++--
src/tir/schedule/concrete_schedule.h | 8 +++--
src/tir/schedule/primitive.h | 4 +--
src/tir/schedule/primitive/get_block_loop.cc | 4 +--
src/tir/schedule/schedule.cc | 2 ++
src/tir/schedule/traced_schedule.cc | 21 +++++++++--
src/tir/schedule/traced_schedule.h | 2 +-
.../python/unittest/test_tir_schedule_utilities.py | 32 ++++++++++++++++-
15 files changed, 225 insertions(+), 58 deletions(-)
diff --git a/include/tvm/tir/schedule/schedule.h
b/include/tvm/tir/schedule/schedule.h
index d95a9d4e7e..8e160c6132 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -115,6 +115,21 @@ class ScheduleNode : public runtime::Object {
virtual ScheduleState state() const = 0;
/*! \return The internally maintained trace of scheduling program execution
*/
virtual Optional<Trace> trace() const = 0;
+ /*!
+ * \brief Instruct the schedule to work on a function in the IRModule.
+ *
+ * By default, the schedule works on the function with the name "main", or
the only function in
+ * the IRModule if there is only one. If there are multiple functions in the
IRModule, and none of
+ * their names are "main", users will have to call this method to explicitly
specify which
+ * function to work on.
+ *
+ * This sugar function will guide the `GetBlock` method if its `func_name`
is not specified.
+ *
+ * \param func_name The name of the function to be working on
+ *
+ * \sa GetBlock
+ */
+ virtual void WorkOn(const String& func_name) = 0;
/*!
* \brief Returns a copy of the schedule, including both its state and its
symbol table,
* guaranteeing that
@@ -231,12 +246,19 @@ class ScheduleNode : public runtime::Object {
/******** Schedule: Get blocks & loops ********/
/*!
* \brief Retrieve a block in a specific function with its name
+ *
+ * By default, if `func_name` is not specified, the schedule will search for
the block in the
+ * function that is currently being "worked on". To switch the function to
be worked on, use
+ * `WorkOn` before calling this method.
+ *
* \param name The name of the block to be retrieved
* \param func_name The name of the function
* \return The block retrieved
* \note Indexing error is raised if 0 or multiple blocks exist with the
specific name
+ *
+ * \sa WorkOn
*/
- virtual BlockRV GetBlock(const String& name, const String& func_name =
"main") = 0;
+ virtual BlockRV GetBlock(const String& name, const Optional<String>&
func_name = NullOpt) = 0;
/*!
* \brief Get the parent loops of the block in its scope, from outer to inner
* \param block_rv The query block
diff --git a/python/tvm/tir/schedule/schedule.py
b/python/tvm/tir/schedule/schedule.py
index 7a1e244604..28bdf63872 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -186,6 +186,23 @@ class Schedule(Object):
"""Returns the internally maintained trace of scheduling program
execution"""
return _ffi_api.ScheduleGetTrace(self) # type: ignore # pylint:
disable=no-member
+ def work_on(self, func_name: str) -> None:
+ """Instruct the schedule to work on a function in the IRModule.
+
+ By default, the schedule works on the function with the name "main",
or the only function in
+ the IRModule if there is only one. If there are multiple functions in
the IRModule, and none
+ of their names are "main", users will have to call this method to
explicitly specify which
+ function to work on.
+
+ This sugar function will guide the `GetBlock` method if its
`func_name` is not specified.
+
+ Parameters
+ ----------
+ func_name : str
+ The name of the function to work on.
+ """
+ _ffi_api.ScheduleWorkOn(self, func_name) # type: ignore # pylint:
disable=no-member
+
def copy(self) -> "Schedule":
"""Returns a copy of the schedule, including both the state and the
symbol table,
* guaranteeing that
@@ -403,15 +420,19 @@ class Schedule(Object):
def get_block(
self,
name: str,
- func_name: str = "main",
+ func_name: Optional[str] = None,
) -> BlockRV:
"""Retrieve a block in a specific function with its name
+ By default, if `func_name` is not specified, the schedule will search
for the block in the
+ function that is currently being "worked on". To switch the function
to be worked on, use
+ `work_on` before calling this method.
+
Parameters
----------
name : str
The name of the block
- func_name : str = "main"
+ func_name : Optional[str] = None
The name of the function
Returns
diff --git a/src/meta_schedule/arg_info.cc b/src/meta_schedule/arg_info.cc
index 672df86deb..21de9d719d 100644
--- a/src/meta_schedule/arg_info.cc
+++ b/src/meta_schedule/arg_info.cc
@@ -21,6 +21,47 @@
namespace tvm {
namespace meta_schedule {
+/*!
+ * \brief Find the entry function of the given IRModule, i.e., the function marked
by
+ * `tir::attr::kIsEntryFunc`, the one named `main`, or the only PrimFunc.
+ * \param mod The IRModule to find the entry function.
+ * \return The entry function.
+ */
+inline tir::PrimFunc FindEntryFunc(const IRModule& mod) {
+ // Priority 1: PrimFunc marked as `tir::attr::kIsEntryFunc`
+ int num_prim_func = 0;
+ const tir::PrimFuncNode* main_func = nullptr;
+ const tir::PrimFuncNode* last_func = nullptr;
+ for (const auto& kv : mod->functions) {
+ GlobalVar gv = kv.first;
+ BaseFunc base_func = kv.second;
+ if (const auto* func = base_func.as<tir::PrimFuncNode>()) {
+ last_func = func;
+ if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
+ return GetRef<tir::PrimFunc>(func);
+ }
+ if (gv->name_hint == "main") {
+ main_func = func;
+ }
+ ++num_prim_func;
+ }
+ }
+ // Priority 2: PrimFunc whose name is `main`
+ if (main_func != nullptr) {
+ return GetRef<tir::PrimFunc>(main_func);
+ }
+ // Priority 3: The only PrimFunc in the IRModule
+ if (num_prim_func == 0) {
+ LOG(FATAL) << "ValueError: Cannot find any PrimFunc in the given IRModule:
"
+ << tir::AsTVMScript(mod);
+ }
+ if (num_prim_func > 1) {
+ LOG(FATAL) << "ValueError: Multiple PrimFuncs exist in the IRModule, but
none of them are "
+ "annotated with `kIsEntryFunc`, i.e. `tir.is_entry_func`"
+ << tir::AsTVMScript(mod);
+ }
+ return GetRef<tir::PrimFunc>(last_func);
+}
/******** ArgInfo ********/
ArgInfo ArgInfo::FromJSON(const ObjectRef& json_obj) {
diff --git a/src/meta_schedule/mutator/mutate_parallel.cc
b/src/meta_schedule/mutator/mutate_parallel.cc
index 7c973879f2..5b7fe7f514 100644
--- a/src/meta_schedule/mutator/mutate_parallel.cc
+++ b/src/meta_schedule/mutator/mutate_parallel.cc
@@ -79,7 +79,8 @@ const BlockRVNode* GetInstGetBlockOutput(const Instruction&
inst) {
std::vector<std::vector<int64_t>> AnalyzeParallel(const ScheduleState& self,
const String& block_name,
const String& func_name,
int64_t limit) {
- Array<StmtSRef> block_srefs = tir::GetBlocks(self, block_name, func_name);
+ Array<StmtSRef> block_srefs =
+ tir::GetBlocks(self, block_name, self->mod->GetGlobalVar(func_name));
ICHECK_EQ(block_srefs.size(), 1);
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_srefs[0]);
ScopeBlockLoopInfo info = GetScopeBlockLoopInfo(GetRef<Block>(block));
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index ca696da71e..b5cb73c26e 100644
--- a/src/meta_schedule/utils.h
+++ b/src/meta_schedule/utils.h
@@ -174,48 +174,6 @@ inline String SHash2Hex(const ObjectRef& obj) {
return os.str();
}
-/*!
- * \brief Find the entry function of the given IRModule, i.e, functions marked
by
- * `tir::attr::kIsEntryFunc`, whose name is `main` or being the only PrimeFunc.
- * \param mod The IRModule to find the entry function.
- * \return The entry function.
- */
-inline tir::PrimFunc FindEntryFunc(const IRModule& mod) {
- // Priority 1: PrimFunc marked as `tir::attr::kIsEntryFunc`
- int num_prim_func = 0;
- const tir::PrimFuncNode* main_func = nullptr;
- const tir::PrimFuncNode* last_func = nullptr;
- for (const auto& kv : mod->functions) {
- GlobalVar gv = kv.first;
- BaseFunc base_func = kv.second;
- if (const auto* func = base_func.as<tir::PrimFuncNode>()) {
- last_func = func;
- if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
- return GetRef<tir::PrimFunc>(func);
- }
- if (gv->name_hint == "main") {
- main_func = func;
- }
- ++num_prim_func;
- }
- }
- // Priority 2: PrimFunc whose name is `main`
- if (main_func != nullptr) {
- return GetRef<tir::PrimFunc>(main_func);
- }
- // Priority 3: The only PrimFunc in the IRModule
- if (num_prim_func == 0) {
- LOG(FATAL) << "ValueError: Cannot find any PrimFunc in the given IRModule:
"
- << tir::AsTVMScript(mod);
- }
- if (num_prim_func > 1) {
- LOG(FATAL) << "ValueError: Multiple PrimFuncs exist in the IRModule, but
none of them are "
- "annotated with `kIsEntryFunc`, i.e. `tir.is_entry_func`"
- << tir::AsTVMScript(mod);
- }
- return GetRef<tir::PrimFunc>(last_func);
-}
-
/*!
* \brief Fork a random state into another, i.e. PRNG splitting.
* The given random state is also mutated.
diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
index b30cef829f..317b3625f0 100644
--- a/src/tir/schedule/analysis.h
+++ b/src/tir/schedule/analysis.h
@@ -71,6 +71,15 @@ const PrimFuncNode* GetRootPrimFunc(const IRModule& mod,
const StmtNode* root_bl
*/
StmtSRef GetSRefTreeRoot(const StmtSRef& sref);
+/*!
+ * \brief Find the entry function of the given IRModule, i.e., the function marked
by
+ * `tir::attr::kIsEntryFunc`, the one named `main`, or the only PrimFunc.
+ * \param mod The IRModule to find the entry function.
+ * \param result_g_var The result GlobalVar of the entry function.
+ * \return The entry function.
+ */
+const PrimFuncNode* FindEntryFunc(const IRModule& mod, GlobalVar*
result_g_var);
+
/******** Scope ********/
/*!
* \brief Checks if scope the specified sref is in is a stage-pipeline and
return it
diff --git a/src/tir/schedule/analysis/analysis.cc
b/src/tir/schedule/analysis/analysis.cc
index 3ee1ed28b8..ac73ac3ce2 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -49,6 +49,47 @@ const PrimFuncNode* GetRootPrimFunc(const IRModule& mod,
const StmtNode* root_bl
throw;
}
+const PrimFuncNode* FindEntryFunc(const IRModule& mod, GlobalVar*
result_g_var) {
+ GlobalVar result = NullValue<GlobalVar>();
+ // Priority 1: PrimFunc marked as `tir::attr::kIsEntryFunc`
+ int num_prim_func = 0;
+ const tir::PrimFuncNode* main_func = nullptr;
+ const tir::PrimFuncNode* last_func = nullptr;
+ for (const auto& kv : mod->functions) {
+ GlobalVar gv = kv.first;
+ BaseFunc base_func = kv.second;
+ if (const auto* func = base_func.as<tir::PrimFuncNode>()) {
+ last_func = func;
+ if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
+ if (result_g_var != nullptr) {
+ *result_g_var = gv;
+ }
+ return func;
+ }
+ if (gv->name_hint == "main") {
+ main_func = func;
+ result = gv;
+ }
+ ++num_prim_func;
+ }
+ }
+ // Priority 2: PrimFunc whose name is `main`
+ if (main_func != nullptr) {
+ if (result_g_var != nullptr) {
+ *result_g_var = result;
+ }
+ return main_func;
+ }
+ // Priority 3: The only PrimFunc in the IRModule
+ if (num_prim_func == 1) {
+ if (result_g_var != nullptr) {
+ *result_g_var = result;
+ }
+ return last_func;
+ }
+ return nullptr;
+}
+
/******** Scope ********/
StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref,
diff --git a/src/tir/schedule/concrete_schedule.cc
b/src/tir/schedule/concrete_schedule.cc
index b2f48753b5..c19735025d 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -31,6 +31,12 @@ Schedule Schedule::Concrete(IRModule mod,
support::LinearCongruentialEngine::TRa
n->symbol_table_ = {};
n->analyzer_ = std::make_unique<arith::Analyzer>();
n->Seed(seed);
+ GlobalVar gv = NullValue<GlobalVar>();
+ if (FindEntryFunc(mod, &gv) != nullptr) {
+ n->func_working_on_ = gv;
+ } else {
+ n->func_working_on_ = NullOpt;
+ }
return Schedule(std::move(n));
}
@@ -177,6 +183,10 @@ class ScheduleCopier {
std::unordered_map<const StmtSRefNode*, StmtSRef> old2new_;
};
+void ConcreteScheduleNode::WorkOn(const String& func_name) {
+ this->func_working_on_ = this->state_->mod->GetGlobalVar(func_name);
+}
+
void ConcreteScheduleNode::Copy(ScheduleState* new_state, TSymbolTable*
new_symbol_table) const {
ScheduleCopier::Copy(this, new_state, new_symbol_table);
new_state->get()->DebugVerify();
@@ -184,6 +194,7 @@ void ConcreteScheduleNode::Copy(ScheduleState* new_state,
TSymbolTable* new_symb
Schedule ConcreteScheduleNode::Copy() {
ObjectPtr<ConcreteScheduleNode> n = make_object<ConcreteScheduleNode>();
+ n->func_working_on_ = this->func_working_on_;
n->error_render_level_ = this->error_render_level_;
ConcreteScheduleNode::Copy(&n->state_, &n->symbol_table_);
n->analyzer_ = std::make_unique<arith::Analyzer>(); // new analyzer needed
because it is stateful
@@ -251,7 +262,7 @@ LoopRV ConcreteScheduleNode::SampleComputeLocation(const
BlockRV& block_rv,
/******** Schedule: Get blocks & loops ********/
-BlockRV ConcreteScheduleNode::GetBlock(const String& name, const String&
func_name) {
+BlockRV ConcreteScheduleNode::GetBlock(const String& name, const
Optional<String>& func_name) {
class NotSingleResult : public ScheduleError {
public:
explicit NotSingleResult(String name, IRModule mod, const Array<StmtSRef>&
blocks)
@@ -286,7 +297,17 @@ BlockRV ConcreteScheduleNode::GetBlock(const String& name,
const String& func_na
IRModule mod_;
Array<Block> blocks_;
};
- Array<StmtSRef> blocks = tir::GetBlocks(this->state_, name, func_name);
+ GlobalVar gv = NullValue<GlobalVar>();
+ if (func_name.defined()) {
+ gv = state_->mod->GetGlobalVar(func_name.value());
+ } else if (func_working_on_.defined()) {
+ gv = this->func_working_on_.value();
+ } else {
+ LOG(FATAL) << "ValueError: `get_block` does not know which function to be
working on. Please "
+ "specify the function name explicitly, or call `work_on` to
specify the function "
+ "before using `get_block`.";
+ }
+ Array<StmtSRef> blocks = tir::GetBlocks(this->state_, name, gv);
if (blocks.size() != 1) {
TVM_TIR_SCHEDULE_BEGIN();
throw NotSingleResult(name, this->state_->mod, blocks);
diff --git a/src/tir/schedule/concrete_schedule.h
b/src/tir/schedule/concrete_schedule.h
index dfbacb530a..feea310bd7 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -38,6 +38,8 @@ class ConcreteScheduleNode : public ScheduleNode {
protected:
/*! \brief The internal state of scheduling */
ScheduleState state_;
+ /*! \brief The function to be worked on. */
+ Optional<GlobalVar> func_working_on_;
/*! \brief The level of error rendering */
ScheduleErrorRenderLevel error_render_level_;
/*! \brief A symbol table that maps random variables to concrete
StmtSRef/Integers */
@@ -50,10 +52,11 @@ class ConcreteScheduleNode : public ScheduleNode {
public:
void VisitAttrs(tvm::AttrVisitor* v) {
// `state_` is not visited
+ // `func_working_on_` is not visited
// `error_render_level_` is not visited
// `symbol_table_` is not visited
// `analyzer_` is not visited
- // `rand_state_` is not visited
+ // `rand_state_` is not visited
}
virtual ~ConcreteScheduleNode() = default;
@@ -61,6 +64,7 @@ class ConcreteScheduleNode : public ScheduleNode {
public:
ScheduleState state() const final { return state_; }
Optional<Trace> trace() const override { return NullOpt; }
+ void WorkOn(const String& func_name) final;
Schedule Copy() override;
void Seed(support::LinearCongruentialEngine::TRandState seed) final;
support::LinearCongruentialEngine::TRandState ForkSeed() final;
@@ -89,7 +93,7 @@ class ConcreteScheduleNode : public ScheduleNode {
LoopRV SampleComputeLocation(const BlockRV& block_rv,
Optional<Integer> decision = NullOpt) override;
/******** Schedule: Get blocks & loops ********/
- BlockRV GetBlock(const String& name, const String& func_name = "main")
override;
+ BlockRV GetBlock(const String& name, const Optional<String>& func_name)
override;
Array<LoopRV> GetLoops(const BlockRV& block_rv) override;
Array<BlockRV> GetChildBlocks(const BlockRV& block_rv) override;
Array<BlockRV> GetChildBlocks(const LoopRV& loop_rv) override;
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index 212571df10..608368fbb3 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -116,10 +116,10 @@ TVM_DLL tir::StmtSRef SampleComputeLocation(
* \brief Retrieves blocks in a specific function with its name
* \param self The schedule state
* \param name The name of the blocks to be retrieved
- * \param func_name The name of the function
+ * \param gv The GlobalVar of the function to retrieve blocks from
* \return A list of blocks with the specific name
*/
-Array<StmtSRef> GetBlocks(const ScheduleState& self, const String& name, const
String& func_name);
+Array<StmtSRef> GetBlocks(const ScheduleState& self, const String& name, const
GlobalVar& gv);
/*!
* \brief Gets the parent loops of the block in its scope, from outer to inner
* \param self The schedule state
diff --git a/src/tir/schedule/primitive/get_block_loop.cc
b/src/tir/schedule/primitive/get_block_loop.cc
index a13e525157..746918ac4e 100644
--- a/src/tir/schedule/primitive/get_block_loop.cc
+++ b/src/tir/schedule/primitive/get_block_loop.cc
@@ -21,7 +21,7 @@
namespace tvm {
namespace tir {
-Array<StmtSRef> GetBlocks(const ScheduleState& self, const String& name, const
String& func_name) {
+Array<StmtSRef> GetBlocks(const ScheduleState& self, const String& name, const
GlobalVar& gv) {
struct Finder : public StmtVisitor {
explicit Finder(const ScheduleState& self, const String& name) :
self_(self), name_(name) {}
@@ -39,7 +39,7 @@ Array<StmtSRef> GetBlocks(const ScheduleState& self, const
String& name, const S
Array<StmtSRef> results_;
};
- BaseFunc func = self->mod->Lookup(func_name);
+ BaseFunc func = self->mod->Lookup(gv);
const auto* prim_func = TVM_TYPE_AS(prim_func, func, PrimFuncNode);
Finder finder(self, name);
finder(prim_func->body);
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index 372d94a150..e386061ebf 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -56,6 +56,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSeed") //
.set_body_method<Schedule>(&ScheduleNode::Seed);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleForkSeed") //
.set_body_method<Schedule>(&ScheduleNode::ForkSeed);
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleWorkOn") //
+ .set_body_method<Schedule>(&ScheduleNode::WorkOn);
/**************** (FFI) Constructor ****************/
diff --git a/src/tir/schedule/traced_schedule.cc
b/src/tir/schedule/traced_schedule.cc
index 733b5d872f..93e4c984a4 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -30,6 +30,12 @@ Schedule Schedule::Traced(IRModule mod,
support::LinearCongruentialEngine::TRand
n->analyzer_ = std::make_unique<arith::Analyzer>();
n->trace_ = Trace();
n->Seed(seed);
+ GlobalVar gv = NullValue<GlobalVar>();
+ if (FindEntryFunc(mod, &gv) != nullptr) {
+ n->func_working_on_ = gv;
+ } else {
+ n->func_working_on_ = NullOpt;
+ }
return Schedule(std::move(n));
}
@@ -37,6 +43,7 @@ Schedule TracedScheduleNode::Copy() {
ObjectPtr<TracedScheduleNode> n = make_object<TracedScheduleNode>();
n->error_render_level_ = this->error_render_level_;
ConcreteScheduleNode::Copy(&n->state_, &n->symbol_table_);
+ n->func_working_on_ = this->func_working_on_;
n->analyzer_ = std::make_unique<arith::Analyzer>(); // new analyzer needed
because it is stateful
n->rand_state_ = ForkSeed();
n->trace_ = Trace(this->trace_->insts, this->trace_->decisions);
@@ -90,13 +97,23 @@ LoopRV TracedScheduleNode::SampleComputeLocation(const
BlockRV& block_rv,
/******** Schedule: Get blocks & loops ********/
-BlockRV TracedScheduleNode::GetBlock(const String& name, const String&
func_name) {
+BlockRV TracedScheduleNode::GetBlock(const String& name, const
Optional<String>& func_name) {
+ GlobalVar gv = NullValue<GlobalVar>();
+ if (func_name.defined()) {
+ gv = state_->mod->GetGlobalVar(func_name.value());
+ } else if (func_working_on_.defined()) {
+ gv = this->func_working_on_.value();
+ } else {
+ LOG(FATAL) << "ValueError: `get_block` does not know which function to be
working on. Please "
+ "specify the function name explicitly, or call `work_on` to
specify the function "
+ "before using `get_block`.";
+ }
BlockRV result = ConcreteScheduleNode::GetBlock(name, func_name);
static const InstructionKind& kind = InstructionKind::Get("GetBlock");
trace_->Append(/*inst=*/Instruction(/*kind=*/kind, //
/*inputs=*/{},
- /*attrs=*/{name, func_name},
+ /*attrs=*/{name, gv->name_hint},
/*outputs=*/{result}));
return result;
}
diff --git a/src/tir/schedule/traced_schedule.h
b/src/tir/schedule/traced_schedule.h
index 178026d9ea..f6405d77a1 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -53,7 +53,7 @@ class TracedScheduleNode : public ConcreteScheduleNode {
Optional<Array<Integer>> decision = NullOpt)
final;
LoopRV SampleComputeLocation(const BlockRV& block_rv, Optional<Integer>
decision = NullOpt) final;
/******** Schedule: Get blocks & loops ********/
- BlockRV GetBlock(const String& name, const String& func_name = "main") final;
+ BlockRV GetBlock(const String& name, const Optional<String>& func_name)
final;
Array<LoopRV> GetLoops(const BlockRV& block_rv) final;
Array<BlockRV> GetChildBlocks(const BlockRV& block_rv) final;
Array<BlockRV> GetChildBlocks(const LoopRV& loop_rv) final;
diff --git a/tests/python/unittest/test_tir_schedule_utilities.py
b/tests/python/unittest/test_tir_schedule_utilities.py
index b7517aab7c..c479555590 100644
--- a/tests/python/unittest/test_tir_schedule_utilities.py
+++ b/tests/python/unittest/test_tir_schedule_utilities.py
@@ -20,7 +20,6 @@ import sys
import pytest
import tvm
import tvm.testing
-
from tvm import tir
from tvm.ir import IRModule
from tvm.script import tir as T
@@ -102,6 +101,29 @@ def matmul_relu_ann2(a: T.handle, b: T.handle, d:
T.handle) -> None:
D[vi, vj] = T.max(C[vi, vj], 0.0)
[email protected]_module
+class ModuleWithMultipleFuncs:
+ @T.prim_func
+ def vector_add(
+ A: T.Buffer[128, "float32"],
+ B: T.Buffer[128, "float32"],
+ ) -> None:
+ for i in range(128):
+ with T.block("init"):
+ vi = T.axis.remap("S", [i])
+ B[vi] = A[vi]
+
+ @T.prim_func
+ def vector_add_2(
+ A: T.Buffer[128, "float32"],
+ B: T.Buffer[128, "float32"],
+ ) -> None:
+ for i in range(128):
+ with T.block("init"):
+ vi = T.axis.remap("S", [i])
+ B[vi] = A[vi]
+
+
# pylint: enable=no-member,invalid-name,unused-variable
use_block_name = tvm.testing.parameter(by_dict={"block_obj": False,
"block_name": True})
@@ -133,6 +155,14 @@ def test_tir_schedule_get_block():
assert block.same_as(matmul.body.block.body.body.body[1].body.block)
+def test_tir_schedule_work_on():
+ sch = tir.Schedule(ModuleWithMultipleFuncs, debug_mask="all")
+ with pytest.raises(ValueError, match="does not know which function to be
working on"):
+ sch.get_block(name="init")
+ sch.work_on(func_name="vector_add")
+ sch.get_block(name="init")
+
+
def test_tir_schedule_get_loops(use_block_name):
# Tests:
# - Schedule.get_loops