This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 7b58e16 [TensorIR] GetProducer, GetConsumer (#506) (#9464)
7b58e16 is described below
commit 7b58e1686f7f7e54b53dfc3c4ca02384f8a7214e
Author: Junru Shao <[email protected]>
AuthorDate: Sat Nov 6 09:07:07 2021 -0700
[TensorIR] GetProducer, GetConsumer (#506) (#9464)
Co-authored-by: Siyuan Feng <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Wuwei Lin <[email protected]>
Co-authored-by: Siyuan Feng <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Wuwei Lin <[email protected]>
---
include/tvm/tir/schedule/schedule.h | 12 ++++
python/tvm/tir/schedule/schedule.py | 30 +++++++++
src/tir/schedule/concrete_schedule.cc | 14 ++++
src/tir/schedule/concrete_schedule.h | 2 +
src/tir/schedule/primitive.h | 14 ++++
src/tir/schedule/primitive/get_block_loop.cc | 78 ++++++++++++++++++++++
src/tir/schedule/schedule.cc | 4 ++
src/tir/schedule/traced_schedule.cc | 22 ++++++
src/tir/schedule/traced_schedule.h | 2 +
.../python/unittest/test_tir_schedule_utilities.py | 42 +++++++++++-
10 files changed, 219 insertions(+), 1 deletion(-)
diff --git a/include/tvm/tir/schedule/schedule.h
b/include/tvm/tir/schedule/schedule.h
index 7bfe605..ffd860d 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -232,6 +232,18 @@ class ScheduleNode : public runtime::Object {
* \return A list of child blocks
*/
virtual Array<BlockRV> GetChildBlocks(const LoopRV& loop_rv) = 0;
+  /*!
+   * \brief Get the producers of a specific block, i.e. the blocks in the same scope that the
+   * given block has a read-after-write or write-after-write dependency on
+   * \param block_rv The block to be queried
+   * \return A list of blocks, the producers of the given block
+   */
+  virtual Array<BlockRV> GetProducers(const BlockRV& block_rv) = 0;
+  /*!
+   * \brief Get the consumers of a specific block, i.e. the blocks in the same scope that have a
+   * read-after-write or write-after-write dependency on the given block
+   * \param block_rv The block to be queried
+   * \return A list of blocks, the consumers of the given block
+   */
+  virtual Array<BlockRV> GetConsumers(const BlockRV& block_rv) = 0;
/******** Schedule: Transform loops ********/
/*!
* \brief Fuse a list of consecutive loops into one. It requires:
diff --git a/python/tvm/tir/schedule/schedule.py
b/python/tvm/tir/schedule/schedule.py
index 0790e4f..884eeb7 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -415,6 +415,36 @@ class Schedule(Object):
"""
return _ffi_api.ScheduleGetChildBlocks(self, block_or_loop) # type:
ignore # pylint: disable=no-member
+ def get_producers(self, block: BlockRV) -> List[BlockRV]:
+ """Get the producers of a specific block
+
+ Parameters
+ ----------
+ block : BlockRV
+ The block in the query
+
+ Returns
+ -------
+ producers : List[BlockRV]
+ A list of producers of the given block
+ """
+ return _ffi_api.ScheduleGetProducers(self, block) # type: ignore #
pylint: disable=no-member
+
+ def get_consumers(self, block: BlockRV) -> List[BlockRV]:
+ """Get the consumers of a specific block
+
+ Parameters
+ ----------
+ block : BlockRV
+ The block in the query
+
+ Returns
+ -------
+ consumers : List[BlockRV]
+ A list of consumers of the given block
+ """
+ return _ffi_api.ScheduleGetConsumers(self, block) # type: ignore #
pylint: disable=no-member
+
########## Schedule: Transform loops ##########
def fuse(self, *loops: List[LoopRV]) -> LoopRV:
"""Fuse a list of consecutive loops into one. It requires:
diff --git a/src/tir/schedule/concrete_schedule.cc
b/src/tir/schedule/concrete_schedule.cc
index 1c741fb..4db4cd4 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -310,6 +310,20 @@ Array<BlockRV> ConcreteScheduleNode::GetChildBlocks(const
LoopRV& loop_rv) {
return result;
}
+Array<BlockRV> ConcreteScheduleNode::GetProducers(const BlockRV& block_rv) {
+  // Resolve the RV to its sref, ask the primitive for the producer srefs, and turn those
+  // srefs back into block random variables via CreateRV.
+  TVM_TIR_SCHEDULE_BEGIN();
+  const StmtSRef block_sref = GetSRef(block_rv);
+  const Array<StmtSRef> producer_srefs = tir::GetProducers(state_, block_sref);
+  return CreateRV<BlockRV>(producer_srefs);
+  TVM_TIR_SCHEDULE_END("get-producers", error_render_level_);
+  // NOTE: presumably only reachable through the error path above; keeps every path ending in
+  // return/throw, as in the other schedule methods.
+  throw;
+}
+
+Array<BlockRV> ConcreteScheduleNode::GetConsumers(const BlockRV& block_rv) {
+  // Resolve the RV to its sref, ask the primitive for the consumer srefs, and turn those
+  // srefs back into block random variables via CreateRV.
+  TVM_TIR_SCHEDULE_BEGIN();
+  const StmtSRef block_sref = GetSRef(block_rv);
+  const Array<StmtSRef> consumer_srefs = tir::GetConsumers(state_, block_sref);
+  return CreateRV<BlockRV>(consumer_srefs);
+  TVM_TIR_SCHEDULE_END("get-consumers", error_render_level_);
+  // NOTE: presumably only reachable through the error path above; keeps every path ending in
+  // return/throw, as in the other schedule methods.
+  throw;
+}
+
/******** Schedule: Transform loops ********/
LoopRV ConcreteScheduleNode::Fuse(const Array<LoopRV>& loop_rvs) {
diff --git a/src/tir/schedule/concrete_schedule.h
b/src/tir/schedule/concrete_schedule.h
index 199faf8..035c16f 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -90,6 +90,8 @@ class ConcreteScheduleNode : public ScheduleNode {
Array<LoopRV> GetLoops(const BlockRV& block_rv) override;
Array<BlockRV> GetChildBlocks(const BlockRV& block_rv) override;
Array<BlockRV> GetChildBlocks(const LoopRV& loop_rv) override;
+ // Dependency queries: the producers/consumers of a block within its scope.
+ Array<BlockRV> GetProducers(const BlockRV& block_rv) override;
+ Array<BlockRV> GetConsumers(const BlockRV& block_rv) override;
/******** Schedule: Transform loops ********/
LoopRV Fuse(const Array<LoopRV>& loop_rvs) override;
Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>&
factors) override;
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index 4e9d00f..cc7e44d 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -105,6 +105,20 @@ Array<StmtSRef> GetLoops(const StmtSRef& block_sref);
* \return A list of leaf blocks inside a specific block/loop
*/
Array<StmtSRef> GetChildBlocks(const ScheduleState& self, const StmtSRef&
parent_sref);
+/*!
+ * \brief Get the producers of a specific block
+ * \param self The schedule state
+ * \param block_sref The block in the query
+ * \return A list of blocks, the producers of the given block
+ */
+Array<StmtSRef> GetProducers(const ScheduleState& self, const StmtSRef& block_sref);
+/*!
+ * \brief Get the consumers of a specific block
+ * \param self The schedule state
+ * \param block_sref The block in the query
+ * \return A list of blocks, the consumers of the given block
+ */
+Array<StmtSRef> GetConsumers(const ScheduleState& self, const StmtSRef& block_sref);
/******** Schedule: Transform loops ********/
/*!
* Split a loop into a list of consecutive loops. It requires:
diff --git a/src/tir/schedule/primitive/get_block_loop.cc
b/src/tir/schedule/primitive/get_block_loop.cc
index 2c4e23d..c044de3 100644
--- a/src/tir/schedule/primitive/get_block_loop.cc
+++ b/src/tir/schedule/primitive/get_block_loop.cc
@@ -77,6 +77,34 @@ Array<StmtSRef> GetChildBlocks(const ScheduleState& self,
const StmtSRef& parent
return std::move(collector.result);
}
+/*!
+ * \brief Collect the producers of `block_sref` inside its scope: the source block of every
+ * read-after-write (RAW) or write-after-write (WAW) edge that ends at `block_sref`.
+ * Each producer is reported once, in the order its first qualifying edge appears.
+ */
+Array<StmtSRef> GetProducers(const ScheduleState& self, const StmtSRef& block_sref) {
+  StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false,
+                                     /*require_subtree_compact_dataflow=*/false);
+  Array<Dependency> edges = self->GetBlockScope(scope_root)->GetDepsByDst(block_sref);
+  Array<StmtSRef> results;
+  results.reserve(edges.size());
+  for (const Dependency& edge : edges) {
+    if (edge->kind != DepKind::kRAW && edge->kind != DepKind::kWAW) {
+      continue;
+    }
+    // If the scope records both a RAW and a WAW edge from the same block, that block would
+    // otherwise be listed twice; keep only its first occurrence.
+    bool already_listed = false;
+    for (const StmtSRef& listed : results) {
+      if (listed.same_as(edge->src)) {
+        already_listed = true;
+        break;
+      }
+    }
+    if (!already_listed) {
+      results.push_back(edge->src);
+    }
+  }
+  return results;
+}
+
+/*!
+ * \brief Collect the consumers of `block_sref` inside its scope: the destination block of every
+ * read-after-write (RAW) or write-after-write (WAW) edge that starts at `block_sref`.
+ * Each consumer is reported once, in the order its first qualifying edge appears.
+ */
+Array<StmtSRef> GetConsumers(const ScheduleState& self, const StmtSRef& block_sref) {
+  StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false,
+                                     /*require_subtree_compact_dataflow=*/false);
+  Array<Dependency> edges = self->GetBlockScope(scope_root)->GetDepsBySrc(block_sref);
+  Array<StmtSRef> results;
+  results.reserve(edges.size());
+  for (const Dependency& edge : edges) {
+    if (edge->kind != DepKind::kRAW && edge->kind != DepKind::kWAW) {
+      continue;
+    }
+    // If the scope records both a RAW and a WAW edge to the same block, that block would
+    // otherwise be listed twice; keep only its first occurrence.
+    bool already_listed = false;
+    for (const StmtSRef& listed : results) {
+      if (listed.same_as(edge->dst)) {
+        already_listed = true;
+        break;
+      }
+    }
+    if (!already_listed) {
+      results.push_back(edge->dst);
+    }
+  }
+  return results;
+}
+
/******** InstructionKind Registration ********/
struct GetBlockTraits : public UnpackedInstTraits<GetBlockTraits> {
@@ -159,9 +187,59 @@ struct GetChildBlocksTraits : public
UnpackedInstTraits<GetChildBlocksTraits> {
friend struct ::tvm::tir::UnpackedInstTraits;
};
+/*!
+ * \brief Instruction traits for `GetProducers`: replays the query on a schedule and prints it
+ * as a `get_producers(block=...)` call.
+ */
+struct GetProducersTraits : public UnpackedInstTraits<GetProducersTraits> {
+  static constexpr const char* kName = "GetProducers";
+  // Marked pure: the query only reads the schedule state.
+  static constexpr bool kIsPure = true;
+
+ private:
+  // One input (the block RV); no attributes, no sampling decisions.
+  static constexpr size_t kNumInputs = 1;
+  static constexpr size_t kNumAttrs = 0;
+  static constexpr size_t kNumDecisions = 0;
+
+  static Array<BlockRV> UnpackedApplyToSchedule(Schedule sch, BlockRV block) {
+    return sch->GetProducers(block);
+  }
+
+  static String UnpackedAsPython(Array<String> outputs, String block) {
+    PythonAPICall call("get_producers");
+    call.Input("block", block);
+    call.OutputList(outputs);
+    return call.Str();
+  }
+
+  template <typename>
+  friend struct ::tvm::tir::UnpackedInstTraits;
+};
+
+/*!
+ * \brief Instruction traits for `GetConsumers`: replays the query on a schedule and prints it
+ * as a `get_consumers(block=...)` call.
+ */
+struct GetConsumersTraits : public UnpackedInstTraits<GetConsumersTraits> {
+  static constexpr const char* kName = "GetConsumers";
+  // Marked pure: the query only reads the schedule state.
+  static constexpr bool kIsPure = true;
+
+ private:
+  // One input (the block RV); no attributes, no sampling decisions.
+  static constexpr size_t kNumInputs = 1;
+  static constexpr size_t kNumAttrs = 0;
+  static constexpr size_t kNumDecisions = 0;
+
+  static Array<BlockRV> UnpackedApplyToSchedule(Schedule sch, BlockRV block) {
+    return sch->GetConsumers(block);
+  }
+
+  static String UnpackedAsPython(Array<String> outputs, String block) {
+    PythonAPICall call("get_consumers");
+    call.Input("block", block);
+    call.OutputList(outputs);
+    return call.Str();
+  }
+
+  template <typename>
+  friend struct ::tvm::tir::UnpackedInstTraits;
+};
+
TVM_REGISTER_INST_KIND_TRAITS(GetBlockTraits);
TVM_REGISTER_INST_KIND_TRAITS(GetLoopsTraits);
TVM_REGISTER_INST_KIND_TRAITS(GetChildBlocksTraits);
+// Register the dependency-query instruction kinds; TracedScheduleNode looks them up by name
+// with InstructionKind::Get("GetProducers") / InstructionKind::Get("GetConsumers").
+TVM_REGISTER_INST_KIND_TRAITS(GetProducersTraits);
+TVM_REGISTER_INST_KIND_TRAITS(GetConsumersTraits);
} // namespace tir
} // namespace tvm
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index a1b582d..a411e40 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -142,6 +142,10 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetChildBlocks")
<< ". Its value is: " << rv;
throw;
});
+// FFI entry points called by the Python `Schedule.get_producers` / `Schedule.get_consumers`.
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetProducers")
+ .set_body_method<Schedule>(&ScheduleNode::GetProducers);
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetConsumers")
+ .set_body_method<Schedule>(&ScheduleNode::GetConsumers);
/******** (FFI) Transform loops ********/
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method<Schedule>(&ScheduleNode::Fuse);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSplit").set_body_method<Schedule>(&ScheduleNode::Split);
diff --git a/src/tir/schedule/traced_schedule.cc
b/src/tir/schedule/traced_schedule.cc
index e05d187..4a028d1 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -119,6 +119,28 @@ Array<BlockRV> TracedScheduleNode::GetChildBlocks(const
LoopRV& loop_rv) {
return results;
}
+Array<BlockRV> TracedScheduleNode::GetProducers(const BlockRV& block_rv) {
+  // Evaluate the query on the concrete schedule first, then record it in the trace.
+  Array<BlockRV> producers = ConcreteScheduleNode::GetProducers(block_rv);
+  static const InstructionKind& kind = InstructionKind::Get("GetProducers");
+  Instruction inst(/*kind=*/kind,
+                   /*inputs=*/{block_rv},
+                   /*attrs=*/{},
+                   /*outputs=*/{producers.begin(), producers.end()});
+  trace_->Append(/*inst=*/inst);
+  return producers;
+}
+
+Array<BlockRV> TracedScheduleNode::GetConsumers(const BlockRV& block_rv) {
+  // Evaluate the query on the concrete schedule first, then record it in the trace.
+  Array<BlockRV> consumers = ConcreteScheduleNode::GetConsumers(block_rv);
+  static const InstructionKind& kind = InstructionKind::Get("GetConsumers");
+  Instruction inst(/*kind=*/kind,
+                   /*inputs=*/{block_rv},
+                   /*attrs=*/{},
+                   /*outputs=*/{consumers.begin(), consumers.end()});
+  trace_->Append(/*inst=*/inst);
+  return consumers;
+}
+
/******** Schedule: Transform loops ********/
LoopRV TracedScheduleNode::Fuse(const Array<LoopRV>& loop_rvs) {
diff --git a/src/tir/schedule/traced_schedule.h
b/src/tir/schedule/traced_schedule.h
index ae726ad..ac36b9c 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -56,6 +56,8 @@ class TracedScheduleNode : public ConcreteScheduleNode {
Array<LoopRV> GetLoops(const BlockRV& block_rv) final;
Array<BlockRV> GetChildBlocks(const BlockRV& block_rv) final;
Array<BlockRV> GetChildBlocks(const LoopRV& loop_rv) final;
+ // Traced overrides: run the concrete query, then append the instruction to the trace.
+ Array<BlockRV> GetProducers(const BlockRV& block_rv) final;
+ Array<BlockRV> GetConsumers(const BlockRV& block_rv) final;
/******** Schedule: Transform loops ********/
LoopRV Fuse(const Array<LoopRV>& loop_rvs) final;
Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>&
factor_rvs) final;
diff --git a/tests/python/unittest/test_tir_schedule_utilities.py
b/tests/python/unittest/test_tir_schedule_utilities.py
index 1596d08..d75bc14 100644
--- a/tests/python/unittest/test_tir_schedule_utilities.py
+++ b/tests/python/unittest/test_tir_schedule_utilities.py
@@ -36,13 +36,31 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
for i, j in T.grid(128, 128):
with T.block("init"):
vi, vj = T.axis.remap("SS", [i, j])
- C[vi, vj] = T.float32(0)
+ C[vi, vj] = 0.0
for k in range(0, 128):
with T.block("update"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
[email protected]_func
+def matmul_relu(a: T.handle, b: T.handle, d: T.handle) -> None:
+ A = T.match_buffer(a, (1024, 1024))
+ B = T.match_buffer(b, (1024, 1024))
+ C = T.alloc_buffer((1024, 1024))
+ D = T.match_buffer(d, (1024, 1024))
+ for i, j, k in T.grid(1024, 1024, 1024):
+ with T.block("matmul"):
+ vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+ with T.init():
+ C[vi, vj] = 0.0
+ C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+ for i, j in T.grid(1024, 1024):
+ with T.block("relu"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ D[vi, vj] = T.max(C[vi, vj], 0.0)
+
+
# pylint: enable=no-member,invalid-name,unused-variable
@@ -159,5 +177,27 @@ def test_get_child_blocks():
assert s.get(update) == s.get(blocks[1])
+def test_get_producers():
+ sch = tir.Schedule(mod=matmul_relu, debug_mask="all")
+ block = sch.get_block("relu")
+ (producer,) = sch.get_producers(block)
+ assert tvm.ir.structural_equal(
+ sch.get_sref(producer).stmt,
+ sch.get_sref(sch.get_block("matmul")).stmt,
+ )
+ verify_trace_roundtrip(sch, mod=matmul_relu)
+
+
+def test_get_consumers():
+ sch = tir.Schedule(mod=matmul_relu, debug_mask="all")
+ block = sch.get_block("matmul")
+ (consumer,) = sch.get_consumers(block)
+ assert tvm.ir.structural_equal(
+ sch.get_sref(consumer).stmt,
+ sch.get_sref(sch.get_block("relu")).stmt,
+ )
+ verify_trace_roundtrip(sch, mod=matmul_relu)
+
+
if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))