This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4973cd3 [TIR][Schedule] Add get-child-blocks primitive (#9434)
4973cd3 is described below
commit 4973cd37cb3b0311c3d06b4bb6c18adaf168d43f
Author: Bohan Hou <[email protected]>
AuthorDate: Sat Nov 6 00:07:34 2021 -0400
[TIR][Schedule] Add get-child-blocks primitive (#9434)
* get child blocks
* fix
* lint
* fix
---
include/tvm/tir/schedule/schedule.h | 12 +++++
python/tvm/tir/schedule/schedule.py | 15 ++++++
src/tir/schedule/concrete_schedule.cc | 18 ++++++++
src/tir/schedule/concrete_schedule.h | 2 +
src/tir/schedule/primitive.h | 7 +++
src/tir/schedule/primitive/get_block_loop.cc | 54 ++++++++++++++++++++++
src/tir/schedule/schedule.cc | 12 +++++
src/tir/schedule/traced_schedule.cc | 22 +++++++++
src/tir/schedule/traced_schedule.h | 2 +
.../python/unittest/test_tir_schedule_utilities.py | 17 +++++++
10 files changed, 161 insertions(+)
diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h
index dfbbf29..7bfe605 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -220,6 +220,18 @@ class ScheduleNode : public runtime::Object {
* \return A list of loops above the given block in its scope, from outer to inner
*/
virtual Array<LoopRV> GetLoops(const BlockRV& block_rv) = 0;
+ /*!
+ * \brief Get the child blocks of a specific scope
+ * \param block_rv The block where the scope is rooted
+ * \return A list of child blocks
+ */
+ virtual Array<BlockRV> GetChildBlocks(const BlockRV& block_rv) = 0;
+ /*!
+ * \brief Get the child blocks under a specific loop
+ * \param loop_rv The loop whose child blocks are collected
+ * \return A list of child blocks
+ */
+ virtual Array<BlockRV> GetChildBlocks(const LoopRV& loop_rv) = 0;
/******** Schedule: Transform loops ********/
/*!
* \brief Fuse a list of consecutive loops into one. It requires:
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index 2663299..0790e4f 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -400,6 +400,21 @@ class Schedule(Object):
"""
return _ffi_api.ScheduleGetLoops(self, block) # type: ignore #
pylint: disable=no-member
+ def get_child_blocks(self, block_or_loop: Union[BlockRV, LoopRV]) ->
List[BlockRV]:
+ """Get the child blocks of a specific block/loop
+
+ Parameters
+ ----------
+ block_or_loop : Union[BlockRV, LoopRV]
+ The query block/loop
+
+ Returns
+ -------
+ blocks : List[BlockRV]
+ A list of child blocks inside a specific block/loop
+ """
+ return _ffi_api.ScheduleGetChildBlocks(self, block_or_loop) # type:
ignore # pylint: disable=no-member
+
########## Schedule: Transform loops ##########
def fuse(self, *loops: List[LoopRV]) -> LoopRV:
"""Fuse a list of consecutive loops into one. It requires:
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index 6801eb2..1c741fb 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -292,6 +292,24 @@ Array<LoopRV> ConcreteScheduleNode::GetLoops(const BlockRV& block_rv) {
return CreateRV<LoopRV>(tir::GetLoops(this->GetSRef(block_rv)));
}
+Array<BlockRV> ConcreteScheduleNode::GetChildBlocks(const BlockRV& block_rv) {
+ Array<BlockRV> result;
+ TVM_TIR_SCHEDULE_BEGIN();
+ result = CreateRV<BlockRV>(tir::GetChildBlocks(state_,
this->GetSRef(block_rv)));
+ TVM_TIR_SCHEDULE_END("get-child-blocks", this->error_render_level_);
+ this->state_->DebugVerify();
+ return result;
+}
+
+Array<BlockRV> ConcreteScheduleNode::GetChildBlocks(const LoopRV& loop_rv) {
+ Array<BlockRV> result;
+ TVM_TIR_SCHEDULE_BEGIN();
+ result = CreateRV<BlockRV>(tir::GetChildBlocks(state_,
this->GetSRef(loop_rv)));
+ TVM_TIR_SCHEDULE_END("get-child-blocks", this->error_render_level_);
+ this->state_->DebugVerify();
+ return result;
+}
+
/******** Schedule: Transform loops ********/
LoopRV ConcreteScheduleNode::Fuse(const Array<LoopRV>& loop_rvs) {
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index 9dd3626..199faf8 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -88,6 +88,8 @@ class ConcreteScheduleNode : public ScheduleNode {
/******** Schedule: Get blocks & loops ********/
BlockRV GetBlock(const String& name, const String& func_name = "main")
override;
Array<LoopRV> GetLoops(const BlockRV& block_rv) override;
+ Array<BlockRV> GetChildBlocks(const BlockRV& block_rv) override;
+ Array<BlockRV> GetChildBlocks(const LoopRV& loop_rv) override;
/******** Schedule: Transform loops ********/
LoopRV Fuse(const Array<LoopRV>& loop_rvs) override;
Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>&
factors) override;
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index f2da7e2..4e9d00f 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -98,6 +98,13 @@ Array<StmtSRef> GetBlocks(const ScheduleState& self, const String& name, const S
* \return A list of loops above the given block in its scope, from outer to inner
*/
Array<StmtSRef> GetLoops(const StmtSRef& block_sref);
+/*!
+ * \brief Get the child blocks of a specific block/loop
+ * \param self The schedule state
+ * \param parent_sref The query block/loop
+ * \return A list of child blocks inside a specific block/loop
+ */
+Array<StmtSRef> GetChildBlocks(const ScheduleState& self, const StmtSRef&
parent_sref);
/******** Schedule: Transform loops ********/
/*!
* Split a loop into a list of consecutive loops. It requires:
diff --git a/src/tir/schedule/primitive/get_block_loop.cc b/src/tir/schedule/primitive/get_block_loop.cc
index 8b32a9c..2c4e23d 100644
--- a/src/tir/schedule/primitive/get_block_loop.cc
+++ b/src/tir/schedule/primitive/get_block_loop.cc
@@ -55,6 +55,28 @@ Array<StmtSRef> GetLoops(const StmtSRef& block_sref) {
return {result.rbegin(), result.rend()};
}
+Array<StmtSRef> GetChildBlocks(const ScheduleState& self, const StmtSRef&
parent_sref) {
+ struct Collector : public StmtVisitor {
+ private:
+ void VisitStmt_(const BlockNode* block) final {
result.push_back(self->stmt2ref.at(block)); }
+
+ public:
+ explicit Collector(const ScheduleState& self) : self(self) {}
+
+ const ScheduleState& self;
+ Array<StmtSRef> result;
+ };
+ Collector collector(self);
+ if (parent_sref->stmt->IsInstance<ForNode>()) {
+ const auto* loop = static_cast<const ForNode*>(parent_sref->stmt);
+ collector(loop->body);
+ } else if (parent_sref->stmt->IsInstance<BlockNode>()) {
+ const auto* block = static_cast<const BlockNode*>(parent_sref->stmt);
+ collector(block->body);
+ }
+ return std::move(collector.result);
+}
+
/******** InstructionKind Registration ********/
struct GetBlockTraits : public UnpackedInstTraits<GetBlockTraits> {
@@ -106,8 +128,40 @@ struct GetLoopsTraits : public UnpackedInstTraits<GetLoopsTraits> {
friend struct ::tvm::tir::UnpackedInstTraits;
};
+struct GetChildBlocksTraits : public UnpackedInstTraits<GetChildBlocksTraits> {
+ static constexpr const char* kName = "GetChildBlocks";
+ static constexpr bool kIsPure = true;
+
+ private:
+ static constexpr size_t kNumInputs = 1;
+ static constexpr size_t kNumAttrs = 0;
+ static constexpr size_t kNumDecisions = 0;
+
+ static Array<BlockRV> UnpackedApplyToSchedule(Schedule sch, ObjectRef
block_or_loop_rv) {
+ if (const auto* block = block_or_loop_rv.as<BlockRVNode>()) {
+ return sch->GetChildBlocks(GetRef<BlockRV>(block));
+ }
+ if (const auto* loop = block_or_loop_rv.as<LoopRVNode>()) {
+ return sch->GetChildBlocks(GetRef<LoopRV>(loop));
+ }
+ LOG(FATAL) << "TypeError: Expected Block or Loop, but gets: " <<
block_or_loop_rv->GetTypeKey();
+ throw;
+ }
+
+ static String UnpackedAsPython(Array<String> outputs, String
block_or_loop_rv) {
+ PythonAPICall py("get_child_blocks");
+ py.Input("", block_or_loop_rv);
+ py.OutputList(outputs);
+ return py.Str();
+ }
+
+ template <typename>
+ friend struct ::tvm::tir::UnpackedInstTraits;
+};
+
TVM_REGISTER_INST_KIND_TRAITS(GetBlockTraits);
TVM_REGISTER_INST_KIND_TRAITS(GetLoopsTraits);
+TVM_REGISTER_INST_KIND_TRAITS(GetChildBlocksTraits);
} // namespace tir
} // namespace tvm
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index 4ef456a..a1b582d 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -130,6 +130,18 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetBlock")
.set_body_method<Schedule>(&ScheduleNode::GetBlock);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetLoops")
.set_body_method<Schedule>(&ScheduleNode::GetLoops);
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetChildBlocks")
+ .set_body_typed([](Schedule self, ObjectRef rv) {
+ if (const auto* block_rv = rv.as<BlockRVNode>()) {
+ return self->GetChildBlocks(GetRef<BlockRV>(block_rv));
+ }
+ if (const auto* loop_rv = rv.as<LoopRVNode>()) {
+ return self->GetChildBlocks(GetRef<LoopRV>(loop_rv));
+ }
+ LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: "
<< rv->GetTypeKey()
+ << ". Its value is: " << rv;
+ throw;
+ });
/******** (FFI) Transform loops ********/
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method<Schedule>(&ScheduleNode::Fuse);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSplit").set_body_method<Schedule>(&ScheduleNode::Split);
diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc
index d1e103c..e05d187 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -97,6 +97,28 @@ Array<LoopRV> TracedScheduleNode::GetLoops(const BlockRV& block_rv) {
return results;
}
+Array<BlockRV> TracedScheduleNode::GetChildBlocks(const BlockRV& block_rv) {
+ Array<BlockRV> results = ConcreteScheduleNode::GetChildBlocks(block_rv);
+
+ static const InstructionKind& kind = InstructionKind::Get("GetChildBlocks");
+ trace_->Append(/*inst=*/Instruction(/*kind=*/kind, //
+ /*inputs=*/{block_rv},
+ /*attrs=*/{},
+ /*outputs=*/{results.begin(),
results.end()}));
+ return results;
+}
+
+Array<BlockRV> TracedScheduleNode::GetChildBlocks(const LoopRV& loop_rv) {
+ Array<BlockRV> results = ConcreteScheduleNode::GetChildBlocks(loop_rv);
+
+ static const InstructionKind& kind = InstructionKind::Get("GetChildBlocks");
+ trace_->Append(/*inst=*/Instruction(/*kind=*/kind, //
+ /*inputs=*/{loop_rv},
+ /*attrs=*/{},
+ /*outputs=*/{results.begin(),
results.end()}));
+ return results;
+}
+
/******** Schedule: Transform loops ********/
LoopRV TracedScheduleNode::Fuse(const Array<LoopRV>& loop_rvs) {
diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h
index 81e0fe84..ae726ad 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -54,6 +54,8 @@ class TracedScheduleNode : public ConcreteScheduleNode {
/******** Schedule: Get blocks & loops ********/
BlockRV GetBlock(const String& name, const String& func_name = "main") final;
Array<LoopRV> GetLoops(const BlockRV& block_rv) final;
+ Array<BlockRV> GetChildBlocks(const BlockRV& block_rv) final;
+ Array<BlockRV> GetChildBlocks(const LoopRV& loop_rv) final;
/******** Schedule: Transform loops ********/
LoopRV Fuse(const Array<LoopRV>& loop_rvs) final;
Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>&
factor_rvs) final;
diff --git a/tests/python/unittest/test_tir_schedule_utilities.py b/tests/python/unittest/test_tir_schedule_utilities.py
index 440d0ab..1596d08 100644
--- a/tests/python/unittest/test_tir_schedule_utilities.py
+++ b/tests/python/unittest/test_tir_schedule_utilities.py
@@ -142,5 +142,22 @@ def test_tir_schedule_remove_rv():
sch.get(block_rv)
+def test_get_child_blocks():
+ s = tir.Schedule(matmul, debug_mask="all")
+ init = s.get_block("init")
+ update = s.get_block("update")
+ # loop
+ blocks = s.get_child_blocks(s.get_loops(init)[0])
+ assert len(blocks) == 2
+ assert s.get(init) == s.get(blocks[0])
+ assert s.get(update) == s.get(blocks[1])
+ # block
+ root = s.get_block("root")
+ blocks = s.get_child_blocks(root)
+ assert len(blocks) == 2
+ assert s.get(init) == s.get(blocks[0])
+ assert s.get(update) == s.get(blocks[1])
+
+
if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))