This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4973cd3  [TIR][Schedule] Add get-child-blocks primitive (#9434)
4973cd3 is described below

commit 4973cd37cb3b0311c3d06b4bb6c18adaf168d43f
Author: Bohan Hou <[email protected]>
AuthorDate: Sat Nov 6 00:07:34 2021 -0400

    [TIR][Schedule] Add get-child-blocks primitive (#9434)
    
    * get child blocks
    
    * fix
    
    * lint
    
    * fix
---
 include/tvm/tir/schedule/schedule.h                | 12 +++++
 python/tvm/tir/schedule/schedule.py                | 15 ++++++
 src/tir/schedule/concrete_schedule.cc              | 18 ++++++++
 src/tir/schedule/concrete_schedule.h               |  2 +
 src/tir/schedule/primitive.h                       |  7 +++
 src/tir/schedule/primitive/get_block_loop.cc       | 54 ++++++++++++++++++++++
 src/tir/schedule/schedule.cc                       | 12 +++++
 src/tir/schedule/traced_schedule.cc                | 22 +++++++++
 src/tir/schedule/traced_schedule.h                 |  2 +
 .../python/unittest/test_tir_schedule_utilities.py | 17 +++++++
 10 files changed, 161 insertions(+)

diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h
index dfbbf29..7bfe605 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -220,6 +220,18 @@ class ScheduleNode : public runtime::Object {
    * \return A list of loops above the given block in its scope, from outer to inner
    */
   virtual Array<LoopRV> GetLoops(const BlockRV& block_rv) = 0;
+  /*!
+   * \brief Get the direct child blocks of a specific scope
+   * \param block_rv The block where the scope is rooted
+   * \return A list of child blocks
+   */
+  virtual Array<BlockRV> GetChildBlocks(const BlockRV& block_rv) = 0;
+  /*!
+   * \brief Get the direct child blocks under a specific loop
+   * \param loop_rv The loop under which the blocks are collected
+   * \return A list of child blocks
+   */
+  virtual Array<BlockRV> GetChildBlocks(const LoopRV& loop_rv) = 0;
   /******** Schedule: Transform loops ********/
   /*!
    * \brief Fuse a list of consecutive loops into one. It requires:
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index 2663299..0790e4f 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -400,6 +400,21 @@ class Schedule(Object):
         """
         return _ffi_api.ScheduleGetLoops(self, block)  # type: ignore # pylint: disable=no-member
 
+    def get_child_blocks(self, block_or_loop: Union[BlockRV, LoopRV]) -> List[BlockRV]:
+        """Get the direct child blocks of a specific block/loop
+
+        Parameters
+        ----------
+        block_or_loop : Union[BlockRV, LoopRV]
+            The query block/loop
+
+        Returns
+        -------
+        blocks : List[BlockRV]
+            A list of child blocks inside a specific block/loop
+        """
+        return _ffi_api.ScheduleGetChildBlocks(self, block_or_loop)  # type: ignore # pylint: disable=no-member
+
     ########## Schedule: Transform loops ##########
     def fuse(self, *loops: List[LoopRV]) -> LoopRV:
         """Fuse a list of consecutive loops into one. It requires:
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index 6801eb2..1c741fb 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -292,6 +292,24 @@ Array<LoopRV> ConcreteScheduleNode::GetLoops(const BlockRV& block_rv) {
   return CreateRV<LoopRV>(tir::GetLoops(this->GetSRef(block_rv)));
 }
 
+Array<BlockRV> ConcreteScheduleNode::GetChildBlocks(const BlockRV& block_rv) {
+  Array<BlockRV> result;
+  TVM_TIR_SCHEDULE_BEGIN();
+  result = CreateRV<BlockRV>(tir::GetChildBlocks(state_, this->GetSRef(block_rv)));
+  TVM_TIR_SCHEDULE_END("get-child-blocks", this->error_render_level_);
+  this->state_->DebugVerify();
+  return result;
+}
+
+Array<BlockRV> ConcreteScheduleNode::GetChildBlocks(const LoopRV& loop_rv) {
+  Array<BlockRV> result;
+  TVM_TIR_SCHEDULE_BEGIN();
+  result = CreateRV<BlockRV>(tir::GetChildBlocks(state_, this->GetSRef(loop_rv)));
+  TVM_TIR_SCHEDULE_END("get-child-blocks", this->error_render_level_);
+  this->state_->DebugVerify();
+  return result;
+}
+
 /******** Schedule: Transform loops ********/
 
 LoopRV ConcreteScheduleNode::Fuse(const Array<LoopRV>& loop_rvs) {
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index 9dd3626..199faf8 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -88,6 +88,8 @@ class ConcreteScheduleNode : public ScheduleNode {
   /******** Schedule: Get blocks & loops ********/
   BlockRV GetBlock(const String& name, const String& func_name = "main") override;
   Array<LoopRV> GetLoops(const BlockRV& block_rv) override;
+  Array<BlockRV> GetChildBlocks(const BlockRV& block_rv) override;
+  Array<BlockRV> GetChildBlocks(const LoopRV& loop_rv) override;
   /******** Schedule: Transform loops ********/
   LoopRV Fuse(const Array<LoopRV>& loop_rvs) override;
   Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors) override;
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index f2da7e2..4e9d00f 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -98,6 +98,13 @@ Array<StmtSRef> GetBlocks(const ScheduleState& self, const String& name, const S
  * \return A list of loops above the given block in its scope, from outer to inner
  */
 Array<StmtSRef> GetLoops(const StmtSRef& block_sref);
+/*!
+ * \brief Get the direct child blocks of a specific block/loop
+ * \param self The schedule state
+ * \param parent_sref The query block/loop
+ * \return A list of child blocks inside a specific block/loop
+ */
+Array<StmtSRef> GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref);
 /******** Schedule: Transform loops ********/
 /*!
  * Split a loop into a list of consecutive loops. It requires:
diff --git a/src/tir/schedule/primitive/get_block_loop.cc b/src/tir/schedule/primitive/get_block_loop.cc
index 8b32a9c..2c4e23d 100644
--- a/src/tir/schedule/primitive/get_block_loop.cc
+++ b/src/tir/schedule/primitive/get_block_loop.cc
@@ -55,6 +55,28 @@ Array<StmtSRef> GetLoops(const StmtSRef& block_sref) {
   return {result.rbegin(), result.rend()};
 }
 
+Array<StmtSRef> GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref) {
+  struct Collector : public StmtVisitor {
+   private:
+    void VisitStmt_(const BlockNode* block) final { result.push_back(self->stmt2ref.at(block)); }
+
+   public:
+    explicit Collector(const ScheduleState& self) : self(self) {}
+
+    const ScheduleState& self;
+    Array<StmtSRef> result;
+  };
+  Collector collector(self);
+  if (parent_sref->stmt->IsInstance<ForNode>()) {
+    const auto* loop = static_cast<const ForNode*>(parent_sref->stmt);
+    collector(loop->body);
+  } else if (parent_sref->stmt->IsInstance<BlockNode>()) {
+    const auto* block = static_cast<const BlockNode*>(parent_sref->stmt);
+    collector(block->body);
+  }
+  return std::move(collector.result);
+}
+
 /******** InstructionKind Registration ********/
 
 struct GetBlockTraits : public UnpackedInstTraits<GetBlockTraits> {
@@ -106,8 +128,40 @@ struct GetLoopsTraits : public UnpackedInstTraits<GetLoopsTraits> {
   friend struct ::tvm::tir::UnpackedInstTraits;
 };
 
+struct GetChildBlocksTraits : public UnpackedInstTraits<GetChildBlocksTraits> {
+  static constexpr const char* kName = "GetChildBlocks";
+  static constexpr bool kIsPure = true;
+
+ private:
+  static constexpr size_t kNumInputs = 1;
+  static constexpr size_t kNumAttrs = 0;
+  static constexpr size_t kNumDecisions = 0;
+
+  static Array<BlockRV> UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv) {
+    if (const auto* block = block_or_loop_rv.as<BlockRVNode>()) {
+      return sch->GetChildBlocks(GetRef<BlockRV>(block));
+    }
+    if (const auto* loop = block_or_loop_rv.as<LoopRVNode>()) {
+      return sch->GetChildBlocks(GetRef<LoopRV>(loop));
+    }
+    LOG(FATAL) << "TypeError: Expected Block or Loop, but gets: " << block_or_loop_rv->GetTypeKey();
+    throw;
+  }
+
+  static String UnpackedAsPython(Array<String> outputs, String block_or_loop_rv) {
+    PythonAPICall py("get_child_blocks");
+    py.Input("", block_or_loop_rv);
+    py.OutputList(outputs);
+    return py.Str();
+  }
+
+  template <typename>
+  friend struct ::tvm::tir::UnpackedInstTraits;
+};
+
 TVM_REGISTER_INST_KIND_TRAITS(GetBlockTraits);
 TVM_REGISTER_INST_KIND_TRAITS(GetLoopsTraits);
+TVM_REGISTER_INST_KIND_TRAITS(GetChildBlocksTraits);
 
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index 4ef456a..a1b582d 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -130,6 +130,18 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetBlock")
     .set_body_method<Schedule>(&ScheduleNode::GetBlock);
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetLoops")
     .set_body_method<Schedule>(&ScheduleNode::GetLoops);
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetChildBlocks")
+    .set_body_typed([](Schedule self, ObjectRef rv) {
+      if (const auto* block_rv = rv.as<BlockRVNode>()) {
+        return self->GetChildBlocks(GetRef<BlockRV>(block_rv));
+      }
+      if (const auto* loop_rv = rv.as<LoopRVNode>()) {
+        return self->GetChildBlocks(GetRef<LoopRV>(loop_rv));
+      }
+      LOG(FATAL) << "TypeError: Expected Block or Loop, but gets: " << rv->GetTypeKey()
+                 << ". Its value is: " << rv;
+      throw;
+    });
 /******** (FFI) Transform loops ********/
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method<Schedule>(&ScheduleNode::Fuse);
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSplit").set_body_method<Schedule>(&ScheduleNode::Split);
diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc
index d1e103c..e05d187 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -97,6 +97,28 @@ Array<LoopRV> TracedScheduleNode::GetLoops(const BlockRV& block_rv) {
   return results;
 }
 
+Array<BlockRV> TracedScheduleNode::GetChildBlocks(const BlockRV& block_rv) {
+  Array<BlockRV> results = ConcreteScheduleNode::GetChildBlocks(block_rv);
+
+  static const InstructionKind& kind = InstructionKind::Get("GetChildBlocks");
+  trace_->Append(/*inst=*/Instruction(/*kind=*/kind,  //
+                                      /*inputs=*/{block_rv},
+                                      /*attrs=*/{},
+                                      /*outputs=*/{results.begin(), results.end()}));
+  return results;
+}
+
+Array<BlockRV> TracedScheduleNode::GetChildBlocks(const LoopRV& loop_rv) {
+  Array<BlockRV> results = ConcreteScheduleNode::GetChildBlocks(loop_rv);
+
+  static const InstructionKind& kind = InstructionKind::Get("GetChildBlocks");
+  trace_->Append(/*inst=*/Instruction(/*kind=*/kind,  //
+                                      /*inputs=*/{loop_rv},
+                                      /*attrs=*/{},
+                                      /*outputs=*/{results.begin(), results.end()}));
+  return results;
+}
+
 /******** Schedule: Transform loops ********/
 
 LoopRV TracedScheduleNode::Fuse(const Array<LoopRV>& loop_rvs) {
diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h
index 81e0fe84..ae726ad 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -54,6 +54,8 @@ class TracedScheduleNode : public ConcreteScheduleNode {
   /******** Schedule: Get blocks & loops ********/
   BlockRV GetBlock(const String& name, const String& func_name = "main") final;
   Array<LoopRV> GetLoops(const BlockRV& block_rv) final;
+  Array<BlockRV> GetChildBlocks(const BlockRV& block_rv) final;
+  Array<BlockRV> GetChildBlocks(const LoopRV& loop_rv) final;
   /******** Schedule: Transform loops ********/
   LoopRV Fuse(const Array<LoopRV>& loop_rvs) final;
   Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factor_rvs) final;
diff --git a/tests/python/unittest/test_tir_schedule_utilities.py b/tests/python/unittest/test_tir_schedule_utilities.py
index 440d0ab..1596d08 100644
--- a/tests/python/unittest/test_tir_schedule_utilities.py
+++ b/tests/python/unittest/test_tir_schedule_utilities.py
@@ -142,5 +142,22 @@ def test_tir_schedule_remove_rv():
         sch.get(block_rv)
 
 
+def test_get_child_blocks():
+    s = tir.Schedule(matmul, debug_mask="all")
+    init = s.get_block("init")
+    update = s.get_block("update")
+    # loop
+    blocks = s.get_child_blocks(s.get_loops(init)[0])
+    assert len(blocks) == 2
+    assert s.get(init) == s.get(blocks[0])
+    assert s.get(update) == s.get(blocks[1])
+    # block
+    root = s.get_block("root")
+    blocks = s.get_child_blocks(root)
+    assert len(blocks) == 2
+    assert s.get(init) == s.get(blocks[0])
+    assert s.get(update) == s.get(blocks[1])
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))

Reply via email to