This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 7b58e16  [TensorIR] GetProducer, GetConsumer (#506) (#9464)
7b58e16 is described below

commit 7b58e1686f7f7e54b53dfc3c4ca02384f8a7214e
Author: Junru Shao <[email protected]>
AuthorDate: Sat Nov 6 09:07:07 2021 -0700

    [TensorIR] GetProducer, GetConsumer (#506) (#9464)
    
    Co-authored-by: Siyuan Feng <[email protected]>
    Co-authored-by: Bohan Hou <[email protected]>
    Co-authored-by: Hongyi Jin <[email protected]>
    Co-authored-by: Ruihang Lai <[email protected]>
    Co-authored-by: Wuwei Lin <[email protected]>
    
    Co-authored-by: Siyuan Feng <[email protected]>
    Co-authored-by: Bohan Hou <[email protected]>
    Co-authored-by: Hongyi Jin <[email protected]>
    Co-authored-by: Ruihang Lai <[email protected]>
    Co-authored-by: Wuwei Lin <[email protected]>
---
 include/tvm/tir/schedule/schedule.h                | 12 ++++
 python/tvm/tir/schedule/schedule.py                | 30 +++++++++
 src/tir/schedule/concrete_schedule.cc              | 14 ++++
 src/tir/schedule/concrete_schedule.h               |  2 +
 src/tir/schedule/primitive.h                       | 14 ++++
 src/tir/schedule/primitive/get_block_loop.cc       | 78 ++++++++++++++++++++++
 src/tir/schedule/schedule.cc                       |  4 ++
 src/tir/schedule/traced_schedule.cc                | 22 ++++++
 src/tir/schedule/traced_schedule.h                 |  2 +
 .../python/unittest/test_tir_schedule_utilities.py | 42 +++++++++++-
 10 files changed, 219 insertions(+), 1 deletion(-)

diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h
index 7bfe605..ffd860d 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -232,6 +232,18 @@ class ScheduleNode : public runtime::Object {
    * \return A list of child blocks
    */
   virtual Array<BlockRV> GetChildBlocks(const LoopRV& loop_rv) = 0;
+  /*!
+   * \brief Get the producers of a specific block
+   * \param block_rv The block in the query
+   * \return A list of blocks, the producers of the given block
+   */
+  virtual Array<BlockRV> GetProducers(const BlockRV& block_rv) = 0;
+  /*!
+   * \brief Get the consumers of a specific block
+   * \param block_rv The block to be queried
+   * \return A list of blocks, the consumers of the given block
+   */
+  virtual Array<BlockRV> GetConsumers(const BlockRV& block_rv) = 0;
   /******** Schedule: Transform loops ********/
   /*!
    * \brief Fuse a list of consecutive loops into one. It requires:
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index 0790e4f..884eeb7 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -415,6 +415,36 @@ class Schedule(Object):
         """
         return _ffi_api.ScheduleGetChildBlocks(self, block_or_loop)  # type: ignore # pylint: disable=no-member
 
+    def get_producers(self, block: BlockRV) -> List[BlockRV]:
+        """Get the producers of a specific block
+
+        Parameters
+        ----------
+        block : BlockRV
+            The block in the query
+
+        Returns
+        -------
+        producers : List[BlockRV]
+            A list of producers of the given block
+        """
+        return _ffi_api.ScheduleGetProducers(self, block)  # type: ignore # pylint: disable=no-member
+
+    def get_consumers(self, block: BlockRV) -> List[BlockRV]:
+        """Get the consumers of a specific block
+
+        Parameters
+        ----------
+        block : BlockRV
+            The block in the query
+
+        Returns
+        -------
+        consumers : List[BlockRV]
+            A list of consumers of the given block
+        """
+        return _ffi_api.ScheduleGetConsumers(self, block)  # type: ignore # pylint: disable=no-member
+
     ########## Schedule: Transform loops ##########
     def fuse(self, *loops: List[LoopRV]) -> LoopRV:
         """Fuse a list of consecutive loops into one. It requires:
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index 1c741fb..4db4cd4 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -310,6 +310,20 @@ Array<BlockRV> ConcreteScheduleNode::GetChildBlocks(const LoopRV& loop_rv) {
   return result;
 }
 
+Array<BlockRV> ConcreteScheduleNode::GetProducers(const BlockRV& block_rv) {
+  TVM_TIR_SCHEDULE_BEGIN();
+  return CreateRV<BlockRV>(tir::GetProducers(state_, this->GetSRef(block_rv)));
+  TVM_TIR_SCHEDULE_END("get-producers", this->error_render_level_);
+  throw;
+}
+
+Array<BlockRV> ConcreteScheduleNode::GetConsumers(const BlockRV& block_rv) {
+  TVM_TIR_SCHEDULE_BEGIN();
+  return CreateRV<BlockRV>(tir::GetConsumers(state_, this->GetSRef(block_rv)));
+  TVM_TIR_SCHEDULE_END("get-consumers", this->error_render_level_);
+  throw;
+}
+
 /******** Schedule: Transform loops ********/
 
 LoopRV ConcreteScheduleNode::Fuse(const Array<LoopRV>& loop_rvs) {
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index 199faf8..035c16f 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -90,6 +90,8 @@ class ConcreteScheduleNode : public ScheduleNode {
   Array<LoopRV> GetLoops(const BlockRV& block_rv) override;
   Array<BlockRV> GetChildBlocks(const BlockRV& block_rv) override;
   Array<BlockRV> GetChildBlocks(const LoopRV& loop_rv) override;
+  Array<BlockRV> GetProducers(const BlockRV& block_rv) override;
+  Array<BlockRV> GetConsumers(const BlockRV& block_rv) override;
   /******** Schedule: Transform loops ********/
   LoopRV Fuse(const Array<LoopRV>& loop_rvs) override;
   Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors) override;
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index 4e9d00f..cc7e44d 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -105,6 +105,20 @@ Array<StmtSRef> GetLoops(const StmtSRef& block_sref);
  * \return A list of leaf blocks inside a specific block/loop
  */
 Array<StmtSRef> GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref);
+/*!
+ * \brief Get the producers of a specific block
+ * \param self The schedule state
+ * \param block_sref The block in the query
+ * \return A list of blocks, the producers of the given block
+ */
+Array<StmtSRef> GetProducers(const ScheduleState& self, const StmtSRef& block_sref);
+/*!
+ * \brief Get the consumers of a specific block
+ * \param self The schedule state
+ * \param block_sref The block in the query
+ * \return A list of blocks, the consumers of the given block
+ */
+Array<StmtSRef> GetConsumers(const ScheduleState& self, const StmtSRef& block_sref);
 /******** Schedule: Transform loops ********/
 /*!
  * Split a loop into a list of consecutive loops. It requires:
diff --git a/src/tir/schedule/primitive/get_block_loop.cc b/src/tir/schedule/primitive/get_block_loop.cc
index 2c4e23d..c044de3 100644
--- a/src/tir/schedule/primitive/get_block_loop.cc
+++ b/src/tir/schedule/primitive/get_block_loop.cc
@@ -77,6 +77,34 @@ Array<StmtSRef> GetChildBlocks(const ScheduleState& self, const StmtSRef& parent
   return std::move(collector.result);
 }
 
+Array<StmtSRef> GetProducers(const ScheduleState& self, const StmtSRef& block_sref) {
+  StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false,
+                                     /*require_subtree_compact_dataflow=*/false);
+  Array<Dependency> edges = self->GetBlockScope(scope_root)->GetDepsByDst(block_sref);
+  Array<StmtSRef> results;
+  results.reserve(edges.size());
+  for (const Dependency& edge : edges) {
+    if (edge->kind == DepKind::kRAW || edge->kind == DepKind::kWAW) {
+      results.push_back(edge->src);
+    }
+  }
+  return results;
+}
+
+Array<StmtSRef> GetConsumers(const ScheduleState& self, const StmtSRef& block_sref) {
+  StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false,
+                                     /*require_subtree_compact_dataflow=*/false);
+  Array<Dependency> edges = self->GetBlockScope(scope_root)->GetDepsBySrc(block_sref);
+  Array<StmtSRef> results;
+  results.reserve(edges.size());
+  for (const Dependency& edge : edges) {
+    if (edge->kind == DepKind::kRAW || edge->kind == DepKind::kWAW) {
+      results.push_back(edge->dst);
+    }
+  }
+  return results;
+}
+
 /******** InstructionKind Registration ********/
 
 struct GetBlockTraits : public UnpackedInstTraits<GetBlockTraits> {
@@ -159,9 +187,59 @@ struct GetChildBlocksTraits : public UnpackedInstTraits<GetChildBlocksTraits> {
   friend struct ::tvm::tir::UnpackedInstTraits;
 };
 
+struct GetProducersTraits : public UnpackedInstTraits<GetProducersTraits> {
+  static constexpr const char* kName = "GetProducers";
+  static constexpr bool kIsPure = true;
+
+ private:
+  static constexpr size_t kNumInputs = 1;
+  static constexpr size_t kNumAttrs = 0;
+  static constexpr size_t kNumDecisions = 0;
+
+  static Array<BlockRV> UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) {
+    return sch->GetProducers(block_rv);
+  }
+
+  static String UnpackedAsPython(Array<String> outputs, String block_rv) {
+    PythonAPICall py("get_producers");
+    py.Input("block", block_rv);
+    py.OutputList(outputs);
+    return py.Str();
+  }
+
+  template <typename>
+  friend struct ::tvm::tir::UnpackedInstTraits;
+};
+
+struct GetConsumersTraits : public UnpackedInstTraits<GetConsumersTraits> {
+  static constexpr const char* kName = "GetConsumers";
+  static constexpr bool kIsPure = true;
+
+ private:
+  static constexpr size_t kNumInputs = 1;
+  static constexpr size_t kNumAttrs = 0;
+  static constexpr size_t kNumDecisions = 0;
+
+  static Array<BlockRV> UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) {
+    return sch->GetConsumers(block_rv);
+  }
+
+  static String UnpackedAsPython(Array<String> outputs, String block_rv) {
+    PythonAPICall py("get_consumers");
+    py.Input("block", block_rv);
+    py.OutputList(outputs);
+    return py.Str();
+  }
+
+  template <typename>
+  friend struct ::tvm::tir::UnpackedInstTraits;
+};
+
 TVM_REGISTER_INST_KIND_TRAITS(GetBlockTraits);
 TVM_REGISTER_INST_KIND_TRAITS(GetLoopsTraits);
 TVM_REGISTER_INST_KIND_TRAITS(GetChildBlocksTraits);
+TVM_REGISTER_INST_KIND_TRAITS(GetProducersTraits);
+TVM_REGISTER_INST_KIND_TRAITS(GetConsumersTraits);
 
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index a1b582d..a411e40 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -142,6 +142,10 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetChildBlocks")
                  << ". Its value is: " << rv;
       throw;
     });
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetProducers")
+    .set_body_method<Schedule>(&ScheduleNode::GetProducers);
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetConsumers")
+    .set_body_method<Schedule>(&ScheduleNode::GetConsumers);
 /******** (FFI) Transform loops ********/
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method<Schedule>(&ScheduleNode::Fuse);
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSplit").set_body_method<Schedule>(&ScheduleNode::Split);
diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc
index e05d187..4a028d1 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -119,6 +119,28 @@ Array<BlockRV> TracedScheduleNode::GetChildBlocks(const LoopRV& loop_rv) {
   return results;
 }
 
+Array<BlockRV> TracedScheduleNode::GetProducers(const BlockRV& block_rv) {
+  Array<BlockRV> results = ConcreteScheduleNode::GetProducers(block_rv);
+
+  static const InstructionKind& kind = InstructionKind::Get("GetProducers");
+  trace_->Append(/*inst=*/Instruction(/*kind=*/kind,  //
+                                      /*inputs=*/{block_rv},
+                                      /*attrs=*/{},
+                                      /*outputs=*/{results.begin(), results.end()}));
+  return results;
+}
+
+Array<BlockRV> TracedScheduleNode::GetConsumers(const BlockRV& block_rv) {
+  Array<BlockRV> results = ConcreteScheduleNode::GetConsumers(block_rv);
+
+  static const InstructionKind& kind = InstructionKind::Get("GetConsumers");
+  trace_->Append(/*inst=*/Instruction(/*kind=*/kind,  //
+                                      /*inputs=*/{block_rv},
+                                      /*attrs=*/{},
+                                      /*outputs=*/{results.begin(), results.end()}));
+  return results;
+}
+
 /******** Schedule: Transform loops ********/
 
 LoopRV TracedScheduleNode::Fuse(const Array<LoopRV>& loop_rvs) {
diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h
index ae726ad..ac36b9c 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -56,6 +56,8 @@ class TracedScheduleNode : public ConcreteScheduleNode {
   Array<LoopRV> GetLoops(const BlockRV& block_rv) final;
   Array<BlockRV> GetChildBlocks(const BlockRV& block_rv) final;
   Array<BlockRV> GetChildBlocks(const LoopRV& loop_rv) final;
+  Array<BlockRV> GetProducers(const BlockRV& block_rv) final;
+  Array<BlockRV> GetConsumers(const BlockRV& block_rv) final;
   /******** Schedule: Transform loops ********/
   LoopRV Fuse(const Array<LoopRV>& loop_rvs) final;
   Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factor_rvs) final;
diff --git a/tests/python/unittest/test_tir_schedule_utilities.py b/tests/python/unittest/test_tir_schedule_utilities.py
index 1596d08..d75bc14 100644
--- a/tests/python/unittest/test_tir_schedule_utilities.py
+++ b/tests/python/unittest/test_tir_schedule_utilities.py
@@ -36,13 +36,31 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
     for i, j in T.grid(128, 128):
         with T.block("init"):
             vi, vj = T.axis.remap("SS", [i, j])
-            C[vi, vj] = T.float32(0)
+            C[vi, vj] = 0.0
         for k in range(0, 128):
             with T.block("update"):
                 vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
 
 
+@T.prim_func
+def matmul_relu(a: T.handle, b: T.handle, d: T.handle) -> None:
+    A = T.match_buffer(a, (1024, 1024))
+    B = T.match_buffer(b, (1024, 1024))
+    C = T.alloc_buffer((1024, 1024))
+    D = T.match_buffer(d, (1024, 1024))
+    for i, j, k in T.grid(1024, 1024, 1024):
+        with T.block("matmul"):
+            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+            with T.init():
+                C[vi, vj] = 0.0
+            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+    for i, j in T.grid(1024, 1024):
+        with T.block("relu"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            D[vi, vj] = T.max(C[vi, vj], 0.0)
+
+
 # pylint: enable=no-member,invalid-name,unused-variable
 
 
@@ -159,5 +177,27 @@ def test_get_child_blocks():
     assert s.get(update) == s.get(blocks[1])
 
 
+def test_get_producers():
+    sch = tir.Schedule(mod=matmul_relu, debug_mask="all")
+    block = sch.get_block("relu")
+    (producer,) = sch.get_producers(block)
+    assert tvm.ir.structural_equal(
+        sch.get_sref(producer).stmt,
+        sch.get_sref(sch.get_block("matmul")).stmt,
+    )
+    verify_trace_roundtrip(sch, mod=matmul_relu)
+
+
+def test_get_consumers():
+    sch = tir.Schedule(mod=matmul_relu, debug_mask="all")
+    block = sch.get_block("matmul")
+    (consumer,) = sch.get_consumers(block)
+    assert tvm.ir.structural_equal(
+        sch.get_sref(consumer).stmt,
+        sch.get_sref(sch.get_block("relu")).stmt,
+    )
+    verify_trace_roundtrip(sch, mod=matmul_relu)
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))

Reply via email to