This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b01ab9e [TensorIR][M2a] CacheRead/Write (#8863)
b01ab9e is described below
commit b01ab9e81e6a9605e6d2dce5b0c81ce551c1839b
Author: Siyuan Feng <[email protected]>
AuthorDate: Wed Sep 1 03:59:53 2021 +0800
[TensorIR][M2a] CacheRead/Write (#8863)
Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Wuwei Lin <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Siyuan Feng <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
---
include/tvm/tir/schedule/schedule.h | 22 +
include/tvm/tir/schedule/state.h | 5 +
python/tvm/tir/schedule/schedule.py | 135 ++++
src/tir/schedule/analysis.h | 21 +-
src/tir/schedule/analysis/analysis.cc | 50 +-
src/tir/schedule/concrete_schedule.cc | 21 +
src/tir/schedule/concrete_schedule.h | 4 +
src/tir/schedule/primitive.h | 24 +
src/tir/schedule/primitive/block_annotate.cc | 4 +-
src/tir/schedule/primitive/cache_read_write.cc | 781 +++++++++++++++++++++
src/tir/schedule/schedule.cc | 4 +
src/tir/schedule/state.cc | 18 +
src/tir/schedule/traced_schedule.cc | 23 +
src/tir/schedule/traced_schedule.h | 4 +
src/tir/schedule/transform.cc | 40 ++
src/tir/schedule/transform.h | 29 +
src/tir/schedule/utils.h | 1 +
.../unittest/test_tir_schedule_cache_read_write.py | 677 ++++++++++++++++++
18 files changed, 1840 insertions(+), 23 deletions(-)
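
For a quick feel for the two new primitives before reading the diff, here is a
minimal usage sketch in Python. It relies only on the API added in this commit;
`before_cache_read` stands for the TVMScript function used in the docstring
examples below:

    from tvm import tir

    sch = tir.Schedule(before_cache_read)
    block_b = sch.get_block("B")
    # Stage read region 0 of block "B" (buffer A) through a "local" cache
    cached = sch.cache_read(block_b, 0, "local")
    # The symmetric write-side primitive:
    # cached = sch.cache_write(block_b, 0, "local")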
diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h
index 79fed09..33776cb 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -282,6 +282,28 @@ class ScheduleNode : public runtime::Object {
*/
virtual void Unroll(const LoopRV& loop_rv) = 0;
/******** Schedule: Insert cache stages ********/
+ /*!
+ * \brief Create a block that reads a buffer region into a read cache. It requires:
+ * 1) There is at most one block that writes the buffer in the scope.
+ * 2) The scope block has the stage-pipeline property.
+ * \param block_rv The consumer block of the target buffer.
+ * \param read_buffer_index The index of the buffer in the block's read region.
+ * \param storage_scope The target storage scope.
+ * \return The cache stage block.
+ */
+ virtual BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index,
+ const String& storage_scope) = 0;
+ /*!
+ * \brief Create a block that writes a buffer region into a write cache. It requires:
+ * 1) There is only one block that writes the target buffer.
+ * 2) The scope block has the stage-pipeline property.
+ * \param block_rv The producer block of the target buffer.
+ * \param write_buffer_index The index of the buffer in the block's write region.
+ * \param storage_scope The target storage scope.
+ * \return The cache stage block.
+ */
+ virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
+ const String& storage_scope) = 0;
/******** Schedule: Compute location ********/
/*!
* \brief Inline a block into its consumer(s). It requires:
diff --git a/include/tvm/tir/schedule/state.h b/include/tvm/tir/schedule/state.h
index 7cd1b00..35299a3 100644
--- a/include/tvm/tir/schedule/state.h
+++ b/include/tvm/tir/schedule/state.h
@@ -129,6 +129,11 @@ class ScheduleStateNode : public Object {
TVM_DLL void Replace(const tir::StmtSRef& src_sref, const Stmt& tgt_stmt,
const Map<Block, Block>& block_sref_reuse);
/*!
+ * \brief Recalculate the `affine_binding` flag of the scope block info.
+ * \param scope_sref The sref of the scope block of interest.
+ */
+ TVM_DLL void UpdateAffineFlag(const StmtSRef& scope_sref);
+ /*!
* \brief Trigger the verification according to the `debug_mask` bitmask.
* 1) If the bitmask `kVerifySRefTree` is on, verify the correctness of the sref tree.
* 2) If the bitmask `kVerifyCachedFlags` is on, verify the correctness of `affine_binding`,
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index 9433d01..ac09bdb 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -790,6 +790,141 @@ class Schedule(Object):
########## Schedule: Insert cache stages ##########
+ def cache_read(self, block: BlockRV, read_buffer_index: int, storage_scope: str) -> BlockRV:
+ """Create a block that reads a buffer region into a read cache. It requires:
+
+ 1) There is at most one block that writes the buffer in the scope.
+
+ 2) The scope block has the stage-pipeline property.
+
+ Parameters
+ ----------
+ block : BlockRV
+ The consumer block of the target buffer.
+
+ read_buffer_index: int
+ The index of the buffer in block's read region.
+
+ storage_scope: str
+ The target storage scope.
+
+ Returns
+ -------
+ cached_block : BlockRV
+ The block of the cache stage
+
+ Examples
+ --------
+ Before cache_read, in TensorIR, the IR is:
+
+ .. code-block:: python
+
+ @tvm.script.tir
+ def before_cache_read(a: ty.handle, b: ty.handle) -> None:
+ A = tir.match_buffer(a, (128, 128))
+ B = tir.match_buffer(b, (128, 128))
+ for i, j in tir.grid(128, 128):
+ with tir.block([128, 128], "B") as [vi, vj]:
+ B[vi, vj] = A[vi, vj] * 2.0
+
+ Create the schedule and cache_read:
+
+ .. code-block:: python
+
+ sch = tir.Schedule(before_cache_read)
+ block_b = sch.get_block("B")
+ sch.cache_read(block_b, 0, "local")
+ print(tvm.script.asscript(sch.mod["main"]))
+
+ After applying cache_read, the IR becomes:
+
+ .. code-block:: python
+
+ @tvm.script.tir
+ def after_cache_read(a: ty.handle, b: ty.handle) -> None:
+ A = tir.match_buffer(a, (128, 128))
+ B = tir.match_buffer(b, (128, 128))
+ A_local = tir.alloc_buffer((128, 128), scope="local")
+ for i, j in tir.grid(128, 128):
+ with tir.block([128, 128], "A_local") as [vi, vj]:
+ A_local[vi, vj] = A[vi, vj]
+ for i, j in tir.grid(128, 128):
+ with tir.block([128, 128], "B") as [vi, vj]:
+ B[vi, vj] = A_local[vi, vj] * 2.0
+
+ """
+ return _ffi_api.ScheduleCacheRead( # type: ignore # pylint: disable=no-member
+ self, block, read_buffer_index, storage_scope
+ )
+
+ def cache_write(self, block: BlockRV, write_buffer_index: int, storage_scope: str) -> BlockRV:
+ """Create a block that writes a buffer region into a write cache. It requires:
+
+ 1) There is only one block that writes the buffer in the scope.
+
+ 2) The scope block has the stage-pipeline property.
+
+ Parameters
+ ----------
+ block : BlockRV
+ The producer block of the target buffer.
+
+ write_buffer_index: int
+ The index of the buffer in block's write region.
+
+ storage_scope: str
+ The target storage scope.
+
+ Returns
+ -------
+ cached_block : BlockRV
+ The block of the cache stage
+
+ Examples
+ --------
+ Before cache_write, in TensorIR, the IR is:
+
+ .. code-block:: python
+
+ @tvm.script.tir
+ def before_cache_write(a: ty.handle, b: ty.handle) -> None:
+ A = tir.match_buffer(a, (128, 128))
+ B = tir.match_buffer(b, (128, 128))
+ for i, j in tir.grid(128, 128):
+ with tir.block([128, 128], "B") as [vi, vj]:
+ B[vi, vj] = A[vi, vj] * 2.0
+
+ Create the schedule and cache_write:
+
+ .. code-block:: python
+
+ sch = tir.Schedule(before_cache_write)
+ block_b = sch.get_block("B")
+ sch.cache_write(block_b, 0, "local")
+ print(tvm.script.asscript(sch.mod["main"]))
+
+ After applying cache_write, the IR becomes:
+
+ .. code-block:: python
+
+ @tvm.script.tir
+ def after_cache_write(a: ty.handle, b: ty.handle) -> None:
+ A = tir.match_buffer(a, (128, 128))
+ B = tir.match_buffer(b, (128, 128))
+ B_local = tir.alloc_buffer((128, 128), scope="local")
+ for i, j in tir.grid(128, 128):
+ with tir.block([128, 128], "A_local") as [vi, vj]:
+ B_local[vi, vj] = A[vi, vj] * 2.0
+ for i, j in tir.grid(128, 128):
+ with tir.block([128, 128], "B") as [vi, vj]:
+ B[vi, vj] = B_local[vi, vj]
+
+ """
+ return _ffi_api.ScheduleCacheWrite( # type: ignore # pylint: disable=no-member
+ self, block, write_buffer_index, storage_scope
+ )
+
########## Schedule: Compute location ##########
def compute_inline(self, block: BlockRV) -> None:
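
The two docstring examples above compose. A short sketch chaining them on the
same function (the resulting IR is what the docstrings suggest; not verified
here):

    sch = tir.Schedule(before_cache_read)
    block_b = sch.get_block("B")
    a_cached = sch.cache_read(block_b, 0, "local")   # reads of A go via A_local
    b_cached = sch.cache_write(block_b, 0, "local")  # writes of B go via a write cache
    print(tvm.script.asscript(sch.mod["main"]))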
diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
index 3fa0c63..d4e4728 100644
--- a/src/tir/schedule/analysis.h
+++ b/src/tir/schedule/analysis.h
@@ -56,6 +56,13 @@ void VerifyCachedFlags(const ScheduleState& self);
const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_block,
GlobalVar* result_g_var);
+/*!
+ * \brief Get the root node of the sref tree, which is the root block of the PrimFunc.
+ * \param sref The given sref.
+ * \return The root node of the sref tree which contains the given node.
+ */
+StmtSRef GetSRefTreeRoot(const StmtSRef& sref);
+
/******** Scope ********/
/*!
* \brief Checks if the scope that the specified sref is in is a stage-pipeline, and returns it
@@ -228,15 +235,15 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr
/******** Block-buffer relation ********/
/*!
- * \brief Get the BlockRealize of the single child block of the block or loop specified by
- * `parent_sref` on SRef tree, or throw an exception if there is 0 or multiple child blocks
- * \param self The schedule state
- * \param block The queried block
- * \param n The index of the queried buffer
- * \return The buffer of the n-th write region of the block.
+ * \brief Get the n-th read or write buffer of the given block.
+ * \param self The schedule state.
+ * \param block The queried block.
+ * \param n The index of the queried buffer.
+ * \param is_write A boolean flag to indicate querying write buffer or read buffer.
+ * \return The buffer of the n-th read/write region of the block.
* \throw ScheduleError If the buffer index is out of bound.
*/
-Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n);
+Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, bool is_write);
/******** Commutative Reducer ********/
diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc
index c9f8ff4..3865781 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -588,25 +588,37 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr
/******** Block-buffer relation ********/
-Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n) {
- class WriteBufferIndexOutOfRangeError : public ScheduleError {
+Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, bool is_write) {
+ class BufferIndexOutOfRangeError : public ScheduleError {
public:
- explicit WriteBufferIndexOutOfRangeError(IRModule mod, Block block, int buffer_index)
- : mod_(std::move(mod)), block_(std::move(block)), buffer_index_(buffer_index) {}
+ explicit BufferIndexOutOfRangeError(IRModule mod, Block block, int buffer_index, bool is_write)
+ : mod_(std::move(mod)),
+ block_(std::move(block)),
+ buffer_index_(buffer_index),
+ is_write_(is_write) {}
String FastErrorString() const final {
- return "ScheduleError: The input `buffer_index` is out of range. It is
required to be in "
- "range [0, num_write_regions) where `num_write_regions` is the
number of buffer "
- "regions written by the block.";
+ if (is_write_) {
+ return "ScheduleError: The input `buffer_index` is out of range. It is
required to be in "
+ "range "
+ "[0, num_write_regions) where `num_write_regions` is the number
of buffer regions "
+ "written by the block.";
+ } else {
+ return "ScheduleError: The input `buffer_index` is out of range. It is
required to be in "
+ "range "
+ "[0, num_read_regions) where `num_read_regions` is the number
of buffer regions "
+ "read by the block.";
+ }
}
String DetailRenderTemplate() const final {
std::ostringstream os;
- size_t num_writes = block_->writes.size();
- os << "The block {0} has " << num_writes
- << " write regions, so `buffer_index` is required to be in [0, " <<
num_writes
+ size_t num = is_write_ ? block_->writes.size() : block_->reads.size();
+ std::string access_type = is_write_ ? "write" : "read";
+ os << "The block {0} has " << num << " " << access_type
+ << " regions, so `buffer_index` is required to be in [0, " << num
<< "). However, the input `buffer_index` is " << buffer_index_
- << ", which is out of the expected range";
+ << ", which is out of the expected range.";
return os.str();
}
@@ -617,12 +629,15 @@ Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n) {
IRModule mod_;
Block block_;
int buffer_index_;
+ bool is_write_;
};
- if (n < 0 || static_cast<size_t>(n) >= block->writes.size()) {
- throw WriteBufferIndexOutOfRangeError(self->mod, block, n);
+ const Array<BufferRegion>& access_region = is_write ? block->writes :
block->reads;
+
+ if (n < 0 || static_cast<int>(access_region.size()) <= n) {
+ throw BufferIndexOutOfRangeError(self->mod, block, n, is_write);
}
- return block->writes[n]->buffer;
+ return access_region[n]->buffer;
}
/******** Pattern Matcher ********/
@@ -941,5 +956,12 @@ bool FromIdentityCombiner(const PrimExpr& identity, const BufferStore& combiner,
return false;
}
+/******** SRef Tree Related ********/
+StmtSRef GetSRefTreeRoot(const StmtSRef& sref) {
+ const StmtSRefNode* p = sref.get();
+ for (; p->parent != nullptr; p = p->parent) {
+ }
+ return GetRef<StmtSRef>(p);
+}
} // namespace tir
} // namespace tvm
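
The reworked GetNthAccessBuffer surfaces in Python as a ScheduleError. A sketch
of the failure mode (assuming the `before_cache_read` docstring function, which
has exactly one read region):

    import pytest

    sch = tir.Schedule(before_cache_read)
    block_b = sch.get_block("B")
    with pytest.raises(tvm.tir.ScheduleError):
        sch.cache_read(block_b, 5, "local")  # only read region 0 exists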
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index cd9aad8..86223e1 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -416,6 +416,27 @@ void ConcreteScheduleNode::Unroll(const LoopRV& loop_rv) {
}
/******** Schedule: Insert cache stages ********/
+
+BlockRV ConcreteScheduleNode::CacheRead(const BlockRV& block_rv, int read_buffer_index,
+ const String& storage_scope) {
+ StmtSRef result{nullptr};
+ TVM_TIR_SCHEDULE_BEGIN();
+ result = tir::CacheRead(state_, this->GetSRef(block_rv), read_buffer_index, storage_scope);
+ TVM_TIR_SCHEDULE_END("cache-read", this->error_render_level_);
+ this->state_->DebugVerify();
+ return CreateRV<BlockRV>(result);
+}
+
+BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer_index,
+ const String& storage_scope) {
+ StmtSRef result{nullptr};
+ TVM_TIR_SCHEDULE_BEGIN();
+ result = tir::CacheWrite(state_, this->GetSRef(block_rv), write_buffer_index, storage_scope);
+ TVM_TIR_SCHEDULE_END("cache-write", this->error_render_level_);
+ this->state_->DebugVerify();
+ return CreateRV<BlockRV>(result);
+}
+
/******** Schedule: Compute location ********/
void ConcreteScheduleNode::ComputeInline(const BlockRV& block_rv) {
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index 0bd902d..e756f9d 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -103,6 +103,10 @@ class ConcreteScheduleNode : public ScheduleNode {
void Bind(const LoopRV& loop_rv, const String& thread_axis) override;
void Unroll(const LoopRV& loop_rv) override;
/******** Schedule: Insert cache stages ********/
+ BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index,
+ const String& storage_scope) override;
+ BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
+ const String& storage_scope) override;
/******** Schedule: Compute location ********/
void ComputeInline(const BlockRV& block) override;
void ReverseComputeInline(const BlockRV& block) override;
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index be33c2a..412611a 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -135,6 +135,30 @@ TVM_DLL void Bind(ScheduleState self, const StmtSRef& loop_sref, const IterVar&
*/
TVM_DLL void Unroll(ScheduleState self, const StmtSRef& loop_sref);
/******** Schedule: Insert cache stages ********/
+/*!
+ * \brief Create a block that reads a buffer region into a read cache. It requires:
+ * 1) There is at most one block that writes the buffer in the scope.
+ * 2) The scope block has the stage-pipeline property.
+ * \param self The state of the schedule
+ * \param block_sref The consumer block of the target buffer.
+ * \param read_buffer_index The index of the buffer in the block's read region.
+ * \param storage_scope The target storage scope.
+ * \return The cache stage block.
+ */
+TVM_DLL StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index,
+ const String& storage_scope);
+/*!
+ * \brief Create a block that writes a buffer region into a write cache. It requires:
+ * 1) There is only one block that writes the target buffer.
+ * 2) The scope block has the stage-pipeline property.
+ * \param self The state of the schedule
+ * \param block_sref The producer block of the target buffer.
+ * \param write_buffer_index The index of the buffer in the block's write region.
+ * \param storage_scope The target storage scope.
+ * \return The cache stage block.
+ */
+TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index,
+ const String& storage_scope);
/******** Schedule: Compute location ********/
/*!
* \brief Inline a block into its consumer(s). It requires:
diff --git a/src/tir/schedule/primitive/block_annotate.cc b/src/tir/schedule/primitive/block_annotate.cc
index 937bc7c..06f7ac3 100644
--- a/src/tir/schedule/primitive/block_annotate.cc
+++ b/src/tir/schedule/primitive/block_annotate.cc
@@ -16,7 +16,6 @@
* specific language governing permissions and limitations
* under the License.
*/
-#include "../transform.h"
#include "../utils.h"
namespace tvm {
@@ -237,7 +236,8 @@ class StorageAlignInvalidAnnotationError : public ScheduleError {
void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_index, int axis,
int factor, int offset) {
const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref);
- Buffer buffer = GetNthWriteBuffer(self, GetRef<Block>(block_ptr), buffer_index);
+ Buffer buffer =
+ GetNthAccessBuffer(self, GetRef<Block>(block_ptr), buffer_index, /*is_write=*/true);
StorageAlignInvalidFactorError::Check(self->mod, factor);
axis = StorageAlignAxisOutOfRangeError::CheckAndUpdate(self->mod, buffer, axis);
NonAllocatedBufferError::CheckBufferAllocated(self->mod, block_sref, buffer);
diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc
new file mode 100644
index 0000000..df54c96
--- /dev/null
+++ b/src/tir/schedule/primitive/cache_read_write.cc
@@ -0,0 +1,781 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "../utils.h"
+
+namespace tvm {
+namespace tir {
+
+/******** Error Classes ********/
+
+class NotSingleWriteBlock : public ScheduleError {
+ public:
+ explicit NotSingleWriteBlock(IRModule mod, Buffer buffer, Array<StmtSRef> write_blocks)
+ : mod_(std::move(mod)), buffer_(std::move(buffer)) {
+ ICHECK_GT(write_blocks.size(), 1);
+ write_blocks_.reserve(write_blocks.size());
+ for (const StmtSRef& block_sref : write_blocks) {
+ const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+ write_blocks_.push_back(GetRef<Block>(block));
+ }
+ }
+
+ String FastErrorString() const final {
+ return "ScheduleError: The buffer is only allowed to be written by a single block.";
+ }
+
+ String DetailRenderTemplate() const final {
+ size_t k = write_blocks_.size();
+ return "The buffer " + buffer_->name + " is expected to be written by a single block, but got " +
+ std::to_string(k) + " blocks that write it.";
+ }
+
+ IRModule mod() const final { return mod_; }
+ Array<ObjectRef> LocationsOfInterest() const final {
+ return {write_blocks_.begin(), write_blocks_.end()};
+ }
+
+ private:
+ IRModule mod_;
+ Buffer buffer_;
+ Array<Block> write_blocks_;
+};
+
+/******** Helper Functions/Classes ********/
+
+/*! \brief The auxiliary info used for the insertion point and content of the cache stage. */
+struct CacheStageInfo {
+ /*! \brief The buffer to be read. */
+ Buffer read_buffer;
+ /*! \brief The buffer to be written. */
+ Buffer write_buffer;
+ /*! \brief The buffer allocation to be inserted into the block signature. */
+ Buffer alloc;
+ /*! \brief The AST node whose body is where the cache stage should be inserted. */
+ StmtSRef loc_sref;
+ /*! \brief The index to insert the cache_read/cache_write stage. */
+ size_t loc_pos;
+ /*! \brief The cache_read/cache_write stage to be inserted. */
+ Stmt cache_stage;
+ /*! \brief The map used for ScheduleStateNode::Replace. */
+ Map<Block, Block> block_reuse;
+};
+
+/*! \brief Return the buffer region related to the given buffer */
+Optional<BufferRegion> GetBufferRegionFromBuffer(const Array<BufferRegion>& buffer_regions,
+ const Buffer& buffer) {
+ Optional<BufferRegion> res = NullOpt;
+ for (const auto& region : buffer_regions) {
+ if (region->buffer.same_as(buffer)) {
+ ICHECK(!res.defined());
+ res = region;
+ }
+ }
+ return res;
+}
+
+/*!
+ * \brief Create a loop nest that represents the cache copy (cache_read / cache_write) from the
+ * read buffer to the write buffer.
+ * \note This function stores the stmt with the loop nesting in the CacheStageInfo, and only
+ * returns the inner block.
+ * \param cache_region The cached copy region.
+ * \param info The cache stage information, which will be updated in the function.
+ * \param storage_scope The storage scope of the cached buffer (only used for naming here)
+ * \returns A block indicating the body of the loop nesting.
+ */
+Block MakeCacheStage(const BufferRegion& cache_region, CacheStageInfo* info,
+ const String& storage_scope) {
+ // loop variables
+ std::vector<Var> loop_vars;
+ // bindings in block realize
+ std::vector<PrimExpr> iter_values;
+ // Create loop vars and block vars' binding_value
+ for (const Range& axis_range : cache_region->region) {
+ Var loop_var("ax" + std::to_string(loop_vars.size()));
+ loop_vars.push_back(loop_var);
+ iter_values.push_back(axis_range->min + loop_var);
+ }
+ // block variables
+ Array<IterVar> block_vars;
+ // block access region for read/write buffers
+ Region access_region;
+ // indices used in block body
+ Array<PrimExpr> access_indices;
+ // Create block vars, block's accessed region and accessing indices
+ for (const PrimExpr& dim : cache_region->buffer->shape) {
+ Var var("v" + std::to_string(access_indices.size()));
+ block_vars.push_back(IterVar(/*dom=*/Range::FromMinExtent(0, dim),
+ /*var=*/var,
+ /*IterVarType=*/kDataPar));
+ access_indices.push_back(var);
+ access_region.push_back(Range::FromMinExtent(var, 1));
+ }
+
+ // Create the body block:
+ // reads = [read_buffer[access_region]]
+ // writes = [write_buffer[access_region]]
+ // write_buffer[access_indices] = read_buffer[access_indices]
+ Block block(
+ /*iter_vars=*/std::move(block_vars),
+ /*reads=*/{BufferRegion(info->read_buffer, access_region)},
+ /*writes=*/{BufferRegion(info->write_buffer, access_region)},
+ /*name_hint=*/cache_region->buffer->name + "_" + storage_scope,
+ /*body=*/
+ BufferStore(info->write_buffer, BufferLoad(info->read_buffer,
access_indices),
+ access_indices),
+ /*init=*/NullOpt,
+ /*alloc_buffers=*/{},
+ /*match_buffers=*/{},
+ /*annotations=*/{});
+ // Create the block realize node
+ Stmt body = BlockRealize(/*values=*/iter_values,
+ /*predicate=*/Bool(true),
+ /*block=*/block);
+ // Create surrounding loops
+ for (size_t i = loop_vars.size(); i >= 1; --i) {
+ body = For(/*loop_var=*/loop_vars[i - 1],
+ /*min=*/0,
+ /*extent=*/cache_region->region[i - 1]->extent,
+ /*kind=*/ForKind::kSerial,
+ /*body=*/body);
+ }
+ info->cache_stage = std::move(body);
+ return block;
+}
+
+/*!
+ * \brief Insert the cache_read/cache_write stage into the specific position
+ * \param stmt A sequence of statements or a single statement that the new stage is inserted in
+ * \param pos The position where the cache stage is inserted
+ * \param stage The stage to be inserted
+ * \return A SeqStmt, the result after insertion
+ */
+SeqStmt InsertCacheStage(const Stmt& stmt, int pos, const Stmt& stage) {
+ if (const auto* seq_stmt = stmt.as<SeqStmtNode>()) {
+ ObjectPtr<SeqStmtNode> result = make_object<SeqStmtNode>(*seq_stmt);
+ result->seq.insert(result->seq.begin() + pos, stage);
+ return SeqStmt(result);
+ }
+ if (pos == 0) {
+ return SeqStmt({stage, stmt});
+ }
+ ICHECK_EQ(pos, 1);
+ return SeqStmt({stmt, stage});
+}
+
+/*!
+ * \brief Get the only writer block of the input buffer in a given scope block.
+ * \param self The state of the schedule
+ * \param scope_sref The scope block where the write is considered
+ * \param buffer The queried buffer
+ * \return The sref of the only writer of the input buffer in the given scope,
+ * or `NullOpt` if no block writes it in the scope.
+ * \throw NotSingleWriteBlock if more than one block writes the buffer in the scope.
+ */
+Optional<StmtSRef> GetOnlyWriteBlock(ScheduleState self, const StmtSRef&
scope_sref,
+ const Buffer& buffer) {
+ BlockScope scope = self->GetBlockScope(scope_sref);
+ auto it = scope->buffer_writers.find(buffer);
+ if (it == scope->buffer_writers.end()) {
+ return NullOpt;
+ } else {
+ const Array<StmtSRef>& block_srefs = it->second;
+ ICHECK(!block_srefs.empty());
+ if (block_srefs.size() > 1) {
+ throw NotSingleWriteBlock(self->mod, buffer, block_srefs);
+ }
+ return block_srefs[0];
+ }
+}
+
+/*!
+ * \brief Get the buffer region under the sref tree path [dom_low_inclusive, dom_high_exclusive)
+ * \param self The state of the schedule.
+ * \param buffer_region The buffer region to be analyzed.
+ * \param block_sref The sref of the block related to the region.
+ * \param dom_low_inclusive The lowest node in the sref tree path.
+ * \param dom_high_exclusive The highest node in the sref tree path.
+ * \return The relaxed buffer region.
+ */
+BufferRegion RelaxBufferRegion(ScheduleState self, const BufferRegion& buffer_region,
+ const StmtSRef& block_sref, const StmtSRef& dom_low_inclusive,
+ const StmtSRef& dom_high_exclusive) {
+ BlockRealize realize = GetBlockRealize(self, block_sref);
+ Map<Var, PrimExpr> binding = GetBindings(realize);
+ const Buffer& buffer = buffer_region->buffer;
+ Array<arith::IntSet> int_sets =
+ arith::EvalSet(Substitute(buffer_region->region, binding),
+ AsIntSet(LoopDomainOfSRefTreePath(
+ /*low_inclusive=*/dom_low_inclusive,
+ /*high_exclusive=*/dom_high_exclusive,
+ /*extra_relax_scope=*/runtime::StorageScope::Create(buffer.scope()))));
+ ICHECK_EQ(buffer_region->region.size(), int_sets.size());
+
+ Region region;
+ region.reserve(int_sets.size());
+ for (size_t i = 0; i < int_sets.size(); ++i) {
+ region.push_back(int_sets[i].CoverRange(Range::FromMinExtent(0, buffer->shape[i])));
+ }
+ return BufferRegion(buffer, region);
+}
+
+/*! \brief Detect the insertion position of the new cache stage */
+class CacheLocDetector : public StmtVisitor {
+ public:
+ /*!
+ * \brief Detect the insertion position of the cache stage, and write the position into the
+ * CacheStageInfo.
+ * \param self The state of the schedule
+ * \param block_sref The sref of the unique writer block of the buffer being applied cache_read
+ * or cache_write
+ * \param scope_sref The sref of the scope block of the cached block
+ * \param info The cache stage info.
+ */
+ static void Detect(const ScheduleState& self, const StmtSRef& block_sref,
+ const StmtSRef& scope_sref, CacheStageInfo* info) {
+ std::vector<StmtSRef> related_blocks;
+ for (const Dependency& def :
self->GetBlockScope(scope_sref)->GetDepsBySrc(block_sref)) {
+ if (def->kind == DepKind::kRAW) {
+ related_blocks.push_back(def->dst);
+ }
+ }
+ if (!related_blocks.empty()) {
+ CacheLocDetector detector(self, block_sref, scope_sref, related_blocks);
+ detector(GetRef<Stmt>(scope_sref->stmt));
+ info->loc_sref = detector.loc_sref_;
+ info->loc_pos = detector.loc_pos_;
+ } else {
+ info->loc_sref = scope_sref;
+ const auto* body = scope_sref->StmtAs<BlockNode>()->body.as<SeqStmtNode>();
+ info->loc_pos = body == nullptr ? 1 : body->size();
+ }
+ }
+
+ private:
+ /*!
+ * \brief Constructor
+ * \param self The state of the schedule
+ * \param block_sref The sref of the unique writer block of the buffer being applied cache_read
+ * or cache_write
+ * \param scope_sref The sref of the scope block of the cached block
+ * \param related_blocks Producer blocks for cache_write, or consumer blocks for cache_read
+ */
+ CacheLocDetector(const ScheduleState self, const StmtSRef& block_sref, const StmtSRef& scope_sref,
+ const std::vector<StmtSRef>& related_blocks)
+ : self_(self),
+ block_sref_(block_sref),
+ scope_sref_(scope_sref),
+ related_blocks_(related_blocks) {}
+
+ void VisitStmt_(const SeqStmtNode* seq_stmt) final {
+ bool previous_visited_block = visited_block_;
+ bool previous_visited_related = visited_related_;
+ visited_block_ = visited_related_ = false;
+
+ int pos = -1;
+ for (size_t i = 0; i < seq_stmt->size(); ++i) {
+ if (loc_pos_ != -1) {
+ break;
+ }
+ VisitStmt(seq_stmt->seq[i]);
+ // `pos` can be assigned only once, after we have visited `block_sref`
+ if (visited_block_ && visited_related_ && pos == -1) {
+ // The offset of insert position from the block
+ pos = i;
+ }
+ }
+ visited_block_ = visited_block_ || previous_visited_block;
+ visited_related_ = visited_related_ || previous_visited_related;
+ // Only once we have visited both the writing block and at least one of the related blocks
+ // do we know that we have found the lowest ancestor
+ // of the block and any one of the related ones
+ if (visited_block_ && visited_related_ && loc_pos_ == -1) {
+ loc_pos_ = pos;
+ }
+ }
+
+ void VisitStmt_(const BlockNode* block) final {
+ // Only visit the current scope under buffer writer's parent block
+ if (block == scope_sref_->stmt) {
+ // The block visited is the current parent scope
+ StmtVisitor::VisitStmt_(block);
+ // Handling cache_read for input buffer
+ if (visited_block_ && visited_related_ && !loc_sref_.defined()) {
+ loc_sref_ = self_->stmt2ref.at(block);
+ if (loc_pos_ == -1) {
+ loc_pos_ = 1;
+ }
+ }
+ return;
+ }
+ // Update `visited_block`
+ if (block_sref_->stmt == block) {
+ visited_block_ = true;
+ return;
+ }
+ // Update `visited_related`
+ for (const StmtSRef& related_block : related_blocks_) {
+ if (related_block->stmt == block) {
+ visited_related_ = true;
+ return;
+ }
+ }
+ }
+
+ void VisitStmt_(const ForNode* loop) final {
+ StmtVisitor::VisitStmt_(loop);
+ if (visited_block_ && visited_related_ && !loc_sref_.defined() && loc_pos_
!= -1) {
+ loc_sref_ = self_->stmt2ref.at(loop);
+ }
+ }
+
+ private:
+ /*! \brief The schedule class */
+ const ScheduleState self_;
+ /*! \brief The dominant block which writes the buffer */
+ const StmtSRef& block_sref_;
+ /*! \brief The parent scope of the dominant block */
+ const StmtSRef& scope_sref_;
+ /*! \brief Producer blocks for cache_write and consumer blocks for cache_read */
+ const std::vector<StmtSRef>& related_blocks_;
+ /*! \brief The flag whether we have visited the dominant block */
+ bool visited_block_{false};
+ /*! \brief The flag whether we have visited at least one related block */
+ bool visited_related_{false};
+ /*! \brief The AST node whose body is where the cache stage should be inserted */
+ StmtSRef loc_sref_{nullptr};
+ /*! \brief The index to insert the cache_read/cache_write stage */
+ int loc_pos_{-1};
+};
+
+/*! \brief Mutator for CacheRead. */
+class CacheReadRewriter : public StmtExprMutator {
+ public:
+ /*!
+ * \brief Rewrite the AST and add a cache_read stage with the information provided
+ * \param scope_sref The parent scope of this mutation
+ * \param info The cache stage information
+ * \return The new AST rooting at the original parent scope
+ */
+ static Stmt Rewrite(const StmtSRef& scope_sref, CacheStageInfo* info) {
+ CacheReadRewriter rewriter(scope_sref, info);
+ return rewriter(GetRef<Stmt>(scope_sref->stmt));
+ }
+
+ private:
+ explicit CacheReadRewriter(const StmtSRef& scope_sref, CacheStageInfo* info)
+ : scope_sref_(scope_sref), info_(info) {}
+
+ Stmt VisitStmt_(const ForNode* loop) final {
+ Stmt stmt = StmtMutator::VisitStmt_(loop);
+ // Check the insertion point
+ if (loop == info_->loc_sref->stmt) {
+ // Insert cache stage into the loop if it is the right place
+ ObjectPtr<ForNode> n = make_object<ForNode>(*stmt.as<ForNode>());
+ n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage);
+ stmt = Stmt(n);
+ }
+ return stmt;
+ }
+
+ Stmt VisitStmt_(const BlockNode* block) final {
+ Block old_stmt = GetRef<Block>(block);
+ // We don't mutate the block which generates info->read_buffer
+ if (block != scope_sref_->stmt &&
+ GetBufferRegionFromBuffer(block->writes,
info_->read_buffer).defined()) {
+ return std::move(old_stmt);
+ }
+ // Mutate the body
+ Block stmt = Downcast<Block>(StmtMutator::VisitStmt_(block));
+ // Check the insertion point
+ if (block == info_->loc_sref->stmt) {
+ // Insert cache stage into the block if it is the right place
+ ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
+ n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage);
+ stmt = Block(n);
+ }
+ // Check if it is the block corresponding to the parent scope
+ if (block == scope_sref_->stmt) {
+ // If so, put buffer allocation on the parent scope
+ ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
+ n->alloc_buffers.push_back(info_->alloc);
+ stmt = Block(n);
+ } else {
+ // Otherwise, update read regions and match_buffers
+ Array<BufferRegion> reads =
+ ReplaceBuffer(block->reads, info_->read_buffer, info_->write_buffer);
+ Array<MatchBufferRegion> match_buffers =
+ ReplaceBuffer(block->match_buffers, info_->read_buffer,
info_->write_buffer);
+ if (!reads.same_as(block->reads) ||
!match_buffers.same_as(block->match_buffers)) {
+ ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
+ n->reads = std::move(reads);
+ n->match_buffers = std::move(match_buffers);
+ stmt = Block(n);
+ }
+ }
+ info_->block_reuse.Set(old_stmt, stmt);
+ return std::move(stmt);
+ }
+
+ PrimExpr VisitExpr_(const BufferLoadNode* load) final {
+ if (load->buffer.same_as(info_->read_buffer)) {
+ ObjectPtr<BufferLoadNode> n = make_object<BufferLoadNode>(*load);
+ n->buffer = info_->write_buffer;
+ return PrimExpr(n);
+ }
+ return ExprMutator::VisitExpr_(load);
+ }
+
+ PrimExpr VisitExpr_(const LoadNode* load) final {
+ if (load->buffer_var.same_as(info_->read_buffer->data)) {
+ ObjectPtr<LoadNode> n = make_object<LoadNode>(*load);
+ n->buffer_var = info_->write_buffer->data;
+ return PrimExpr(n);
+ }
+ return ExprMutator::VisitExpr_(load);
+ }
+
+ PrimExpr VisitExpr_(const VarNode* op) final {
+ if (op == info_->read_buffer->data.get()) {
+ return info_->write_buffer->data;
+ }
+ return GetRef<PrimExpr>(op);
+ }
+
+ private:
+ /*! \brief The parent scope of the insertion */
+ const StmtSRef& scope_sref_;
+ /*! \brief The info for inserting cache stage */
+ CacheStageInfo* info_;
+};
+
+/*! \brief Mutator for CacheWrite */
+class CacheWriteRewriter : public StmtExprMutator {
+ public:
+ /*!
+ * \brief Rewrite the AST and add a cache_write stage with the information provided.
+ * \param scope_sref The parent scope of this mutation.
+ * \param writer_block_sref The only writer block in the scope.
+ * \param info The cache stage information.
+ * \return The new AST rooting at the original parent scope.
+ */
+ static Stmt Rewrite(const StmtSRef& scope_sref, const StmtSRef&
writer_block_sref,
+ CacheStageInfo* info) {
+ CacheWriteRewriter rewriter(scope_sref, writer_block_sref, info);
+ return rewriter(GetRef<Stmt>(scope_sref->stmt));
+ }
+
+ private:
+ explicit CacheWriteRewriter(const StmtSRef& scope_sref, const StmtSRef& writer_block_sref,
+ CacheStageInfo* info)
+ : scope_sref_(scope_sref), writer_block_sref_(writer_block_sref), info_(info) {}
+
+ Stmt VisitStmt_(const ForNode* loop) final {
+ Stmt stmt = StmtMutator::VisitStmt_(loop);
+ // Check the insertion point
+ if (loop == info_->loc_sref->stmt) {
+ // Insert cache stage into the loop if it is the right place
+ ObjectPtr<ForNode> n = make_object<ForNode>(*stmt.as<ForNode>());
+ n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage);
+ stmt = Stmt(n);
+ }
+ return stmt;
+ }
+
+ Stmt VisitStmt_(const BlockNode* block) final {
+ Block old_stmt = GetRef<Block>(block);
+ // We only mutate the block which generates info->write_buffer
+ if (block != writer_block_sref_->stmt && block != scope_sref_->stmt &&
!under_writer_block_) {
+ return std::move(old_stmt);
+ }
+
+ // Mutate the body
+ bool under_scope = under_writer_block_ || block == writer_block_sref_->stmt;
+ std::swap(under_scope, under_writer_block_);
+ Block stmt = Downcast<Block>(StmtMutator::VisitStmt_(block));
+ std::swap(under_scope, under_writer_block_);
+
+ // Find the insertion point
+ if (block == info_->loc_sref->stmt) {
+ ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
+ n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage);
+ stmt = Block(n);
+ }
+ // Put buffer allocation on the parent scope
+ if (block == scope_sref_->stmt) {
+ ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
+ n->alloc_buffers.push_back(info_->alloc);
+ stmt = Block(n);
+ } else {
+ // Since cache_write changes the block, we need to update the buffer it writes
+ auto writes = ReplaceBuffer(block->writes, info_->write_buffer, info_->read_buffer);
+ auto reads = ReplaceBuffer(block->reads, info_->write_buffer, info_->read_buffer);
+ auto match_buffers =
+ ReplaceBuffer(block->match_buffers, info_->write_buffer, info_->read_buffer);
+ if (!writes.same_as(block->writes) || !reads.same_as(block->reads) ||
+ !match_buffers.same_as(block->match_buffers)) {
+ ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
+ n->writes = std::move(writes);
+ n->reads = std::move(reads);
+ n->match_buffers = std::move(match_buffers);
+ stmt = Block(n);
+ }
+ }
+ info_->block_reuse.Set(old_stmt, stmt);
+ return std::move(stmt);
+ }
+
+ Stmt VisitStmt_(const BufferStoreNode* store) final {
+ BufferStore stmt = Downcast<BufferStore>(StmtMutator::VisitStmt_(store));
+ if (stmt->buffer.same_as(info_->write_buffer)) {
+ auto n = CopyOnWrite(stmt.get());
+ n->buffer = info_->read_buffer;
+ return Stmt(n);
+ } else {
+ return std::move(stmt);
+ }
+ }
+
+ PrimExpr VisitExpr_(const BufferLoadNode* load) final {
+ if (load->buffer.same_as(info_->write_buffer)) {
+ ObjectPtr<BufferLoadNode> n = make_object<BufferLoadNode>(*load);
+ n->buffer = info_->read_buffer;
+ return PrimExpr(n);
+ }
+ return ExprMutator::VisitExpr_(load);
+ }
+
+ PrimExpr VisitExpr_(const LoadNode* load) final {
+ if (load->buffer_var.same_as(info_->write_buffer->data)) {
+ ObjectPtr<LoadNode> n = make_object<LoadNode>(*load);
+ n->buffer_var = info_->read_buffer->data;
+ return PrimExpr(n);
+ }
+ return ExprMutator::VisitExpr_(load);
+ }
+
+ Stmt VisitStmt_(const StoreNode* store) final {
+ if (store->buffer_var.same_as(info_->write_buffer->data)) {
+ ObjectPtr<StoreNode> n = make_object<StoreNode>(*store);
+ n->buffer_var = info_->read_buffer->data;
+ return Stmt(n);
+ }
+ return StmtMutator::VisitStmt_(store);
+ }
+
+ PrimExpr VisitExpr_(const VarNode* op) final {
+ if (op == info_->write_buffer->data.get()) {
+ return info_->read_buffer->data;
+ }
+ return GetRef<PrimExpr>(op);
+ }
+
+ private:
+ /*! \brief The parent scope of the insertion. */
+ const StmtSRef& scope_sref_;
+ /*! \brief The only writer block in the scope. */
+ const StmtSRef& writer_block_sref_;
+ /*! \brief The info for inserting cache stage. */
+ CacheStageInfo* info_;
+ /*! \brief Whether the current node is under the given block. */
+ bool under_writer_block_{false};
+};
+
+/******** Implementation ********/
+
+StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index,
+ const String& storage_scope) {
+ /*!
+ * Check:
+ * - The index is in the array of the block's read regions
+ * - There is at most one block that writes the buffer in the scope
+ *
+ * Mutate:
+ * - Allocate a new cache buffer under the current scope.
+ * - Find the lowest ancestor of the block and ANY ONE of the consumer blocks.
+ * - Copy the buffer with the consumed region.
+ */
+
+ // Step 1. Check index, getting the target buffer and the parent scope
+ const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+ Buffer read_buffer =
+ GetNthAccessBuffer(self, GetRef<Block>(block), read_buffer_index, /*is_write=*/false);
+ StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true);
+ const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref);
+
+ // Step 2. Create CacheStageInfo
+ CacheStageInfo info;
+ info.read_buffer = read_buffer;
+ // Create the corresponding buffer to be written, i.e. result of cache_read
+ info.write_buffer = WithScope(read_buffer, storage_scope);
+ // Create the corresponding buffer allocation
+ info.alloc = info.write_buffer;
+
+ // Step 3. Update cache stage info.
+ BufferRegion cache_region{nullptr};
+ if (Optional<StmtSRef> _write_block_sref = GetOnlyWriteBlock(self,
scope_sref, read_buffer)) {
+ // Case 1. The buffer is written inside the block.
+ StmtSRef write_block_sref = _write_block_sref.value();
+ const BlockNode* write_block = TVM_SREF_TO_BLOCK(write_block,
write_block_sref);
+ // Find the producing region
+ BufferRegion region = GetBufferRegionFromBuffer(write_block->writes,
read_buffer).value();
+ StmtSRef parent_sref = GetRef<StmtSRef>(write_block_sref->parent);
+
+ // Detect insert position
+ CacheLocDetector::Detect(self, write_block_sref, scope_sref, &info);
+ cache_region = RelaxBufferRegion(self, region, write_block_sref,
parent_sref, info.loc_sref);
+ } else {
+ // Case 2. The buffer is an input buffer of the scope.
+ info.loc_sref = scope_sref;
+ info.loc_pos = 0;
+ if (Optional<BufferRegion> region =
+ GetBufferRegionFromBuffer(scope_block->reads, read_buffer)) {
+ cache_region = region.value();
+ } else {
+ cache_region = BufferRegion::FullRegion(read_buffer);
+ }
+ }
+
+ // Step 4. Make the new cache stage block and rewrite readers.
+ Block cache_read_stage = MakeCacheStage(/*cache_region=*/cache_region, /*info=*/&info,
+ /*storage_scope=*/storage_scope);
+ Stmt new_scope = CacheReadRewriter::Rewrite(/*scope_sref=*/scope_sref, /*info=*/&info);
+
+ // Step 5. Replace the scope and update flags.
+ self->Replace(scope_sref, new_scope, info.block_reuse);
+ StmtSRef result_block_sref = self->stmt2ref.at(cache_read_stage.get());
+ self->UpdateAffineFlag(result_block_sref);
+ BlockInfo& block_info = self->block_info[result_block_sref];
+ block_info.region_cover = true;
+ block_info.scope->stage_pipeline = true;
+ return result_block_sref;
+}
+
+StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index,
+ const String& storage_scope) {
+ /*!
+ * Check:
+ * - The index is in the array of the block's write regions
+ * - There is only one block that writes the buffer in the scope
+ *
+ * Mutate:
+ * - Allocate a new cache buffer under the current scope.
+ * - Find the lowest ancestor of the block and ANY ONE of the producer blocks.
+ * - Copy the buffer with the produced region.
+ */
+ // Step 1. Check index, getting the target buffer and the parent scope
+ const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+ Buffer write_buffer =
+ GetNthAccessBuffer(self, GetRef<Block>(block), write_buffer_index, /*is_write=*/true);
+ StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true);
+
+ // Step 2. Create CacheStageInfo
+ CacheStageInfo info;
+ info.read_buffer = WithScope(write_buffer, storage_scope);
+ // Create the corresponding buffer to be written, i.e. result of cache_write
+ info.write_buffer = write_buffer;
+ // Create the corresponding buffer allocation
+ info.alloc = info.read_buffer;
+
+ // Step 3. Check the only writer block.
+ ICHECK_EQ(block_sref.get(), GetOnlyWriteBlock(self, scope_sref, write_buffer).get());
+
+ // Step 4. Find the producing region and insert position
+ BufferRegion region = GetBufferRegionFromBuffer(block->writes, write_buffer).value();
+ StmtSRef parent_sref = GetRef<StmtSRef>(block_sref->parent);
+ // Detect insert position
+ CacheLocDetector::Detect(self, block_sref, scope_sref, &info);
+ BufferRegion cache_region =
+ RelaxBufferRegion(self, region, block_sref, parent_sref, info.loc_sref);
+
+ // Step 5. Make the new cache stage block and rewrite the writer.
+ Block cache_write_stage = MakeCacheStage(/*cache_region=*/cache_region, /*info=*/&info,
+ /*storage_scope=*/storage_scope);
+ Stmt new_scope = CacheWriteRewriter::Rewrite(/*scope_sref=*/scope_sref,
+ /*writer_block_sref=*/block_sref, /*info=*/&info);
+
+ // Step 6. Replace the scope and update flags.
+ self->Replace(scope_sref, new_scope, info.block_reuse);
+ StmtSRef result_block_sref = self->stmt2ref.at(cache_write_stage.get());
+ self->UpdateAffineFlag(result_block_sref);
+ BlockInfo& block_info = self->block_info[result_block_sref];
+ block_info.region_cover = true;
+ block_info.scope->stage_pipeline = true;
+ return result_block_sref;
+}
+
+/******** Instruction Registration ********/
+
+struct CacheReadTraits : public UnpackedInstTraits<CacheReadTraits> {
+ static constexpr const char* kName = "CacheRead";
+ static constexpr bool kIsPure = false;
+
+ private:
+ static constexpr size_t kNumInputs = 1;
+ static constexpr size_t kNumAttrs = 2;
+ static constexpr size_t kNumDecisions = 0;
+
+ static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer read_buffer_index,
+ String storage_scope) {
+ return sch->CacheRead(block, read_buffer_index->value, storage_scope);
+ }
+
+ static String UnpackedAsPython(Array<String> outputs, String block, Integer read_buffer_index,
+ String storage_scope) {
+ PythonAPICall py("cache_read");
+ py.Input("block", block);
+ py.Input("read_buffer_index", read_buffer_index->value);
+ py.Input("storage_scope", storage_scope);
+ py.SingleOutput(outputs);
+ return py.Str();
+ }
+
+ template <typename>
+ friend struct ::tvm::tir::UnpackedInstTraits;
+};
+
+struct CacheWriteTraits : public UnpackedInstTraits<CacheWriteTraits> {
+ static constexpr const char* kName = "CacheWrite";
+ static constexpr bool kIsPure = false;
+
+ private:
+ static constexpr size_t kNumInputs = 1;
+ static constexpr size_t kNumAttrs = 2;
+ static constexpr size_t kNumDecisions = 0;
+
+ static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer write_buffer_index,
+ String storage_scope) {
+ return sch->CacheWrite(block, write_buffer_index->value, storage_scope);
+ }
+
+ static String UnpackedAsPython(Array<String> outputs, String block, Integer write_buffer_index,
+ String storage_scope) {
+ PythonAPICall py("cache_write");
+ py.Input("block", block);
+ py.Input("write_buffer_index", write_buffer_index->value);
+ py.Input("storage_scope", storage_scope);
+ py.SingleOutput(outputs);
+ return py.Str();
+ }
+
+ template <typename>
+ friend struct ::tvm::tir::UnpackedInstTraits;
+};
+
+TVM_REGISTER_INST_KIND_TRAITS(CacheReadTraits);
+TVM_REGISTER_INST_KIND_TRAITS(CacheWriteTraits);
+} // namespace tir
+} // namespace tvm
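
The subtlest piece above is CacheLocDetector: the cache stage must be placed at
the lowest position that is still after the writer block and before every
RAW-dependent block. A sketch of what this means in practice, using the
`func_multi_consumer` workload from the tests below, where block "A" is
consumed by block "B" inside the `i` loop and by block "C" at the root scope:

    sch = tir.Schedule(func_multi_consumer)
    block_b = sch.get_block("B")
    # The cache stage (named "A_global" by MakeCacheStage) must dominate both
    # consumers, so it cannot sink into the `i` loop and should land in the
    # root block's body instead
    sch.cache_read(block_b, 0, "global")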
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index d24cdc6..fd30b02 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -141,6 +141,10 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleVectorize")
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBind").set_body_method<Schedule>(&ScheduleNode::Bind);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnroll").set_body_method<Schedule>(&ScheduleNode::Unroll);
/******** (FFI) Insert cache stages ********/
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheRead")
+ .set_body_method<Schedule>(&ScheduleNode::CacheRead);
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheWrite")
+ .set_body_method<Schedule>(&ScheduleNode::CacheWrite);
/******** (FFI) Compute location ********/
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeInline")
.set_body_method<Schedule>(&ScheduleNode::ComputeInline);
diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc
index 9a9b974..799806b 100644
--- a/src/tir/schedule/state.cc
+++ b/src/tir/schedule/state.cc
@@ -1029,6 +1029,24 @@ TVM_DLL Array<Bool> GetCachedFlags(const ScheduleState& self, const StmtSRef& bl
Bool(info.scope->stage_pipeline)};
}
+TVM_DLL void ScheduleStateNode::UpdateAffineFlag(const StmtSRef& scope_sref) {
+ auto it = this->block_info.find(scope_sref);
+ ICHECK(it != this->block_info.end()) << "Cannot find the block info of the
given block.";
+ BlockInfo& info = it->second;
+
+ bool is_root_block = scope_sref->parent == nullptr;
+ if (is_root_block) {
+ info.affine_binding = true;
+ } else {
+ BlockRealize realize = GetBlockRealize(GetRef<ScheduleState>(this), scope_sref);
+ arith::Analyzer analyzer;
+ StmtSRef parent_sref = GetRef<StmtSRef>(scope_sref->parent);
+ info.affine_binding = IsAffineBinding(/*realize=*/realize,
+ /*loop_var_ranges=*/LoopDomainOfSRefTreePath(parent_sref),
+ /*analyzer=*/&analyzer);
+ }
+}
+
/**************** FFI ****************/
TVM_REGISTER_NODE_TYPE(ScheduleStateNode);
diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc
index af4a658..f429a91 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -166,6 +166,29 @@ void TracedScheduleNode::Unroll(const LoopRV& loop_rv) {
}
/******** Schedule: Insert cache stages ********/
+BlockRV TracedScheduleNode::CacheRead(const BlockRV& block_rv, int read_buffer_index,
+ const String& storage_scope) {
+ BlockRV result = ConcreteScheduleNode::CacheRead(block_rv, read_buffer_index, storage_scope);
+
+ static const InstructionKind& kind = InstructionKind::Get("CacheRead");
+ trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
+ /*inputs=*/{block_rv},
+ /*attrs=*/{Integer(read_buffer_index), storage_scope},
+ /*outputs=*/{result}));
+ return result;
+}
+
+BlockRV TracedScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer_index,
+ const String& storage_scope) {
+ BlockRV result = ConcreteScheduleNode::CacheWrite(block_rv, write_buffer_index, storage_scope);
+
+ static const InstructionKind& kind = InstructionKind::Get("CacheWrite");
+ trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
+ /*inputs=*/{block_rv},
+ /*attrs=*/{Integer(write_buffer_index), storage_scope},
+ /*outputs=*/{result}));
+ return result;
+}
/******** Schedule: Compute location ********/
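
With the trace hooks above, each primitive round-trips through
UnpackedAsPython. A sketch of the expected trace entry, assuming the default
Schedule constructor builds a traced schedule (which verify_trace_roundtrip in
the tests relies on); variable names are illustrative:

    sch = tir.Schedule(before_cache_read)
    b0 = sch.get_block("B")
    b1 = sch.cache_read(b0, 0, "local")
    print(sch.trace)
    # Expected to contain a line like:
    #   b1 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="local")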
diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h
index 48dadbc..a6b5251 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -71,6 +71,10 @@ class TracedScheduleNode : public ConcreteScheduleNode {
void Bind(const LoopRV& loop_rv, const String& thread_axis) final;
void Unroll(const LoopRV& loop_rv) final;
/******** Schedule: Insert cache stages ********/
+ BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index,
+ const String& storage_scope) final;
+ BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
+ const String& storage_scope) final;
/******** Schedule: Compute location ********/
void ComputeInline(const BlockRV& block_rv) final;
void ReverseComputeInline(const BlockRV& block_rv) final;
diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc
index f27e0f6..da376fd 100644
--- a/src/tir/schedule/transform.cc
+++ b/src/tir/schedule/transform.cc
@@ -19,6 +19,8 @@
#include "./transform.h"
+#include "./utils.h"
+
namespace tvm {
namespace tir {
@@ -31,5 +33,43 @@ Block WithAnnotation(const BlockNode* block, const String& attr_key, const Objec
return Block(new_block);
}
+/******** Buffer Related ********/
+Buffer WithScope(const Buffer& buffer, const String& scope) {
+ ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*buffer.get());
+ ObjectPtr<VarNode> new_var = make_object<VarNode>(*buffer->data.get());
+ const auto* ptr_type = TVM_TYPE_AS(ptr_type, buffer->data->type_annotation, PointerTypeNode);
+ new_var->type_annotation = PointerType(ptr_type->element_type, scope);
+ new_buffer->data = Var(new_var->name_hint + "_" + scope, new_var->type_annotation);
+ new_buffer->name = buffer->name + "_" + scope;
+ return Buffer(new_buffer);
+}
+
+Array<BufferRegion> ReplaceBuffer(Array<BufferRegion> regions, const Buffer&
source,
+ const Buffer& target) {
+ regions.MutateByApply([&source, &target](BufferRegion region) ->
BufferRegion {
+ if (region->buffer.same_as(source)) {
+ ObjectPtr<BufferRegionNode> n =
make_object<BufferRegionNode>(*region.get());
+ n->buffer = target;
+ return BufferRegion(n);
+ }
+ return region;
+ });
+ return regions;
+}
+
+Array<MatchBufferRegion> ReplaceBuffer(Array<MatchBufferRegion> match_buffers,
const Buffer& source,
+ const Buffer& target) {
+ match_buffers.MutateByApply([&source,
+ &target](MatchBufferRegion match_buffer) ->
MatchBufferRegion {
+ if (match_buffer->source->buffer.same_as(source)) {
+ ObjectPtr<MatchBufferRegionNode> n =
make_object<MatchBufferRegionNode>(*match_buffer.get());
+ n->source = BufferRegion(target, n->source->region);
+ return MatchBufferRegion(n);
+ }
+ return match_buffer;
+ });
+ return match_buffers;
+}
+
} // namespace tir
} // namespace tvm
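
A note on WithScope: it derives both a fresh data Var and a fresh buffer name by
suffixing the storage scope, which is where names like "A_local" in the
docstring examples come from. The helper itself is C++-internal, but the naming
convention is observable from Python; a sketch:

    sch = tir.Schedule(before_cache_read)
    block_b = sch.get_block("B")
    sch.cache_read(block_b, 0, "shared")
    # The printed module should now allocate a buffer named "A_shared"
    print(tvm.script.asscript(sch.mod["main"]))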
diff --git a/src/tir/schedule/transform.h b/src/tir/schedule/transform.h
index 5348382..85cce9d 100644
--- a/src/tir/schedule/transform.h
+++ b/src/tir/schedule/transform.h
@@ -35,6 +35,35 @@ namespace tir {
*/
Block WithAnnotation(const BlockNode* block, const String& attr_key, const ObjectRef& attr_value);
+/******** Buffer Related ********/
+
+/*!
+ * \brief Create a new buffer by changing the storage scope.
+ * \param buffer The given buffer.
+ * \param scope The target storage scope.
+ * \return The new buffer with target storage scope.
+ */
+Buffer WithScope(const Buffer& buffer, const String& scope);
+
+/*!
+ * \brief Replaces the buffer within the specific sequence of regions
+ * \param regions The regions whose buffers are to be replaced
+ * \param source The buffer to be replaced
+ * \param target The buffer to replace it with
+ * \return The new sequence of regions after replacement
+ */
+Array<BufferRegion> ReplaceBuffer(Array<BufferRegion> regions, const Buffer&
source,
+ const Buffer& target);
+
+/*!
+ * \brief Replaces the buffer within the specific sequence of match_buffers
+ * \param match_buffers The match_buffers whose buffers are to be replaced
+ * \param source The buffer to be replaced
+ * \param target The buffer to replace it with
+ * \return The new sequence of match_buffers after replacement
+ */
+Array<MatchBufferRegion> ReplaceBuffer(Array<MatchBufferRegion> match_buffers,
const Buffer& source,
+ const Buffer& target);
} // namespace tir
} // namespace tvm
diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h
index 8ccf8da..c2f4301 100644
--- a/src/tir/schedule/utils.h
+++ b/src/tir/schedule/utils.h
@@ -42,6 +42,7 @@
#include "./error.h"
#include "./instruction_traits.h"
#include "./primitive.h"
+#include "./transform.h"
namespace tvm {
namespace tir {
diff --git a/tests/python/unittest/test_tir_schedule_cache_read_write.py b/tests/python/unittest/test_tir_schedule_cache_read_write.py
new file mode 100644
index 0000000..d7eb8d8
--- /dev/null
+++ b/tests/python/unittest/test_tir_schedule_cache_read_write.py
@@ -0,0 +1,677 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import sys
+
+import pytest
+import tvm
+from tvm import tir
+from tvm.script import ty
+from tvm.tir.schedule.testing import verify_trace_roundtrip
+
+# pylint: disable=no-member,invalid-name,unused-variable
+
+########## Function before schedule ##########
+
+
[email protected]
+def elementwise(a: ty.handle, c: ty.handle) -> None:
+ A = tir.match_buffer(a, (128, 128))
+ B = tir.alloc_buffer((128, 128))
+ C = tir.match_buffer(c, (128, 128))
+ with tir.block([128, 128], "B") as [vi, vj]:
+ B[vi, vj] = A[vi, vj] * 2.0
+ with tir.block([128, 128], "C") as [vi, vj]:
+ C[vi, vj] = B[vi, vj] + 1.0
+
+
[email protected]
+def access_under_scope(b: ty.handle, c: ty.handle) -> None:
+ A = tir.alloc_buffer((128, 128))
+ B = tir.match_buffer(b, (128, 128))
+ C = tir.match_buffer(c, (128, 128))
+
+ with tir.block([8, 8], "scope") as [i, j]:
+ for x, y in tir.grid(16, 16):
+ with tir.block([128, 128], "A") as [vi, vj]:
+ tir.bind(vi, i * 16 + x)
+ tir.bind(vj, j * 16 + y)
+ A[vi, vj] = 1.0
+ for x, y in tir.grid(16, 16):
+ with tir.block([128, 128], "B") as [vi, vj]:
+ tir.bind(vi, i * 16 + x)
+ tir.bind(vj, j * 16 + y)
+ B[vi, vj] = A[vi, vj] + 1.0
+
+ with tir.block([128, 128], "C") as [vi, vj]:
+ C[vi, vj] = A[vi, vj] * 2.0
+
+
[email protected]
+def opaque_access(a: ty.handle, b: ty.handle, c: ty.handle, d: ty.handle) ->
None:
+ A = tir.match_buffer(a, (128, 128), dtype="float16")
+ B = tir.match_buffer(b, (128, 128), dtype="float16")
+ C = tir.match_buffer(c, (128, 128), dtype="float16")
+ D = tir.match_buffer(d, (128, 128), dtype="float16")
+
+ with tir.block([128, 128], "load_store") as [vi, vj]:
+ tir.reads(A[vi, vj])
+ tir.writes(D[vi, vj])
+ D.data[vi * 128 + vj] = tir.load("float16", A.data, vi * 128 + vj)
+ with tir.block([8, 8], "opaque") as [vi, vj]:
+ tir.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
+ tir.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
+ tir.evaluate(
+ tir.tvm_load_matrix_sync(
+ B.data,
+ 16,
+ 16,
+ 16,
+ vi * 8 + vj,
+ tir.tvm_access_ptr(
+ tir.type_annotation(dtype="float16"),
+ A.data,
+ vi * 2048 + vj * 16,
+ 128,
+ 1,
+ dtype="handle",
+ ),
+ 128,
+ "row_major",
+ dtype="handle",
+ )
+ )
+ with tir.block([8, 8], "match_buffer") as [vi, vj]:
+ tir.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
+ tir.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
+ A0 = tir.match_buffer(
+ A[
+ vi * 16 : vi * 16 + 16,
+ vj * 16 : vj * 16 + 16,
+ ],
+ (16, 16),
+ "float16",
+ strides=[128, 1],
+ offset_factor=1,
+ )
+ C0 = tir.match_buffer(
+ C[
+ vi * 16 : vi * 16 + 16,
+ vj * 16 : vj * 16 + 16,
+ ],
+ (16, 16),
+ "float16",
+ strides=[128, 1],
+ offset_factor=1,
+ )
+ tir.evaluate(
+ tir.tvm_load_matrix_sync(
+ C0.data,
+ 16,
+ 16,
+ 16,
+ vi * 8 + vj,
+ tir.tvm_access_ptr(
+ tir.type_annotation(dtype="float16"),
+ A0.data,
+ A0.elem_offset,
+ A0.strides[0],
+ 1,
+ dtype="handle",
+ ),
+ 128,
+ "row_major",
+ dtype="handle",
+ )
+ )
+
+
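+# Buffer A is consumed by both block "B" and block "C", which exercises cache
+# stage placement when a buffer has multiple consumers.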
[email protected]
+def func_multi_consumer() -> None:
+ A = tir.alloc_buffer((128))
+ B = tir.alloc_buffer((128))
+ C = tir.alloc_buffer((128))
+ for i in tir.grid(8):
+ for j in tir.grid(16):
+ with tir.block([128], "A") as [vi]:
+ tir.bind(vi, i * 16 + j)
+ A[vi] = 1.0
+ for j in tir.grid(16):
+ with tir.block([128], "B") as [vi]:
+ tir.bind(vi, i * 16 + j)
+ B[vi] = A[vi] + 1.0
+ for i in tir.grid(128):
+ with tir.block([128], "C") as [vi]:
+ C[vi] = A[vi]
+
+
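+# Buffer A is written by two blocks ("A0" and "A1"), so caching it is expected
+# to raise a ScheduleError.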
[email protected]
+def func_multi_producer() -> None:
+ A = tir.alloc_buffer((128))
+ B = tir.alloc_buffer((128))
+ with tir.block([128], "A0") as [vi]:
+ A[vi] = 1.0
+ with tir.block([128], "A1") as [vi]:
+ A[vi] = 2.0
+ with tir.block([128], "B") as [vi]:
+ B[vi] = A[vi]
+
+
+########## Expected function after cache_read ##########
+
+
[email protected]
+def cache_read_elementwise(a: ty.handle, c: ty.handle) -> None:
+ A = tir.match_buffer(a, (128, 128))
+ C = tir.match_buffer(c, (128, 128))
+ B = tir.alloc_buffer((128, 128))
+ A_global = tir.alloc_buffer((128, 128))
+ B_local = tir.alloc_buffer((128, 128), scope="local")
+ with tir.block([128, 128], "A_global") as [vi, vj]:
+ A_global[vi, vj] = A[vi, vj]
+ with tir.block([128, 128], "B") as [vi, vj]:
+ B[vi, vj] = A_global[vi, vj] * 2.0
+ with tir.block([128, 128], "B_local") as [vi, vj]:
+ B_local[vi, vj] = B[vi, vj]
+ with tir.block([128, 128], "C") as [vi, vj]:
+ C[vi, vj] = B_local[vi, vj] + 1.0
+
+
[email protected]
+def cache_read_under_scope(b: ty.handle, c: ty.handle) -> None:
+ A = tir.alloc_buffer((128, 128))
+ B = tir.match_buffer(b, (128, 128))
+ C = tir.match_buffer(c, (128, 128))
+ A_global = tir.alloc_buffer((128, 128))
+
+ with tir.block([8, 8], "scope") as [i, j]:
+ A_local = tir.alloc_buffer((128, 128), scope="local")
+ for x, y in tir.grid(16, 16):
+ with tir.block([128, 128], "A") as [vi, vj]:
+ tir.bind(vi, i * 16 + x)
+ tir.bind(vj, j * 16 + y)
+ A[vi, vj] = 1.0
+ for x, y in tir.grid(16, 16):
+ with tir.block([128, 128], "A_local") as [vi, vj]:
+ tir.bind(vi, i * 16 + x)
+ tir.bind(vj, j * 16 + y)
+ A_local[vi, vj] = A[vi, vj]
+ for x, y in tir.grid(16, 16):
+ with tir.block([128, 128], "B") as [vi, vj]:
+ tir.bind(vi, i * 16 + x)
+ tir.bind(vj, j * 16 + y)
+ B[vi, vj] = A_local[vi, vj] + 1.0
+ with tir.block([128, 128], "A_global") as [vi, vj]:
+ A_global[vi, vj] = A[vi, vj]
+ with tir.block([128, 128], "C") as [vi, vj]:
+ C[vi, vj] = A_global[vi, vj] * 2.0
+
+
[email protected]
+def cache_read_opaque_access(a: ty.handle, b: ty.handle, c: ty.handle, d:
ty.handle) -> None:
+ A = tir.match_buffer(a, (128, 128), dtype="float16")
+ B = tir.match_buffer(b, (128, 128), dtype="float16")
+ C = tir.match_buffer(c, (128, 128), dtype="float16")
+ D = tir.match_buffer(d, (128, 128), dtype="float16")
+ A_global = tir.alloc_buffer((128, 128), dtype="float16")
+
+ with tir.block([128, 128], "A_global") as [vi, vj]:
+ A_global[vi, vj] = A[vi, vj]
+ with tir.block([128, 128], "load_store") as [vi, vj]:
+ tir.reads(A_global[vi, vj])
+ tir.writes(D[vi, vj])
+        D.data[vi * 128 + vj] = tir.load("float16", A_global.data, vi * 128 + vj)
+ with tir.block([8, 8], "opaque") as [vi, vj]:
+ tir.reads(A_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
+ tir.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
+ tir.evaluate(
+ tir.tvm_load_matrix_sync(
+ B.data,
+ 16,
+ 16,
+ 16,
+ vi * 8 + vj,
+ tir.tvm_access_ptr(
+ tir.type_annotation(dtype="float16"),
+ A_global.data,
+ vi * 2048 + vj * 16,
+ 128,
+ 1,
+ dtype="handle",
+ ),
+ 128,
+ "row_major",
+ dtype="handle",
+ )
+ )
+ with tir.block([8, 8], "match_buffer") as [vi, vj]:
+ tir.reads(A_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
+ tir.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
+ A0 = tir.match_buffer(
+ A_global[
+ vi * 16 : vi * 16 + 16,
+ vj * 16 : vj * 16 + 16,
+ ],
+ (16, 16),
+ "float16",
+ strides=[128, 1],
+ offset_factor=1,
+ )
+ C0 = tir.match_buffer(
+ C[
+ vi * 16 : vi * 16 + 16,
+ vj * 16 : vj * 16 + 16,
+ ],
+ (16, 16),
+ "float16",
+ strides=[128, 1],
+ offset_factor=1,
+ )
+ tir.evaluate(
+ tir.tvm_load_matrix_sync(
+ C0.data,
+ 16,
+ 16,
+ 16,
+ vi * 8 + vj,
+ tir.tvm_access_ptr(
+ tir.type_annotation(dtype="float16"),
+ A0.data,
+ A0.elem_offset,
+ A0.strides[0],
+ 1,
+ dtype="handle",
+ ),
+ 128,
+ "row_major",
+ dtype="handle",
+ )
+ )
+
+
[email protected]
+def cache_read_multi_consumer() -> None:
+ A = tir.alloc_buffer((128))
+ B = tir.alloc_buffer((128))
+ C = tir.alloc_buffer((128))
+ A_global = tir.alloc_buffer((128))
+ for i in tir.grid(8):
+ for j in tir.grid(16):
+ with tir.block([128], "A") as [vi]:
+ tir.bind(vi, i * 16 + j)
+ A[vi] = 1.0
+ for j in tir.grid(16):
+ with tir.block([128], "A") as [vi]:
+ tir.bind(vi, i * 16 + j)
+ A_global[vi] = A[vi]
+ for j in tir.grid(16):
+ with tir.block([128], "B") as [vi]:
+ tir.bind(vi, i * 16 + j)
+ B[vi] = A_global[vi] + 1.0
+
+ for i in tir.grid(128):
+ with tir.block([128], "C") as [vi]:
+ C[vi] = A_global[vi]
+
+
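+# Expected result of applying cache_read twice to block "C": the read of B is
+# staged through B -> B_shared -> B_local.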
[email protected]
+def continuous_cache_read(a: ty.handle, c: ty.handle) -> None:
+ A = tir.match_buffer(a, (128, 128))
+ C = tir.match_buffer(c, (128, 128))
+ B = tir.alloc_buffer((128, 128))
+ B_shared = tir.alloc_buffer((128, 128), scope="shared")
+ B_local = tir.alloc_buffer((128, 128), scope="local")
+ with tir.block([128, 128], "B") as [vi, vj]:
+ B[vi, vj] = A[vi, vj] * 2.0
+ with tir.block([128, 128], "B_shared") as [vi, vj]:
+ B_shared[vi, vj] = B[vi, vj]
+ with tir.block([128, 128], "B_local") as [vi, vj]:
+ B_local[vi, vj] = B_shared[vi, vj]
+ with tir.block([128, 128], "C") as [vi, vj]:
+ C[vi, vj] = B_local[vi, vj] + 1.0
+
+
+########## Expected function after cache_write ##########
+
+
[email protected]
+def cache_write_elementwise(a: ty.handle, c: ty.handle) -> None:
+ A = tir.match_buffer(a, (128, 128))
+ C = tir.match_buffer(c, (128, 128))
+ B = tir.alloc_buffer((128, 128))
+    B_local = tir.alloc_buffer((128, 128), scope="local")
+    C_global = tir.alloc_buffer((128, 128))
+    with tir.block([128, 128], "B_local") as [vi, vj]:
+        B_local[vi, vj] = A[vi, vj] * 2.0
+    with tir.block([128, 128], "B") as [vi, vj]:
+        B[vi, vj] = B_local[vi, vj]
+    with tir.block([128, 128], "C_global") as [vi, vj]:
+        C_global[vi, vj] = B[vi, vj] + 1.0
+    with tir.block([128, 128], "C") as [vi, vj]:
+        C[vi, vj] = C_global[vi, vj]
+
+
[email protected]
+def cache_write_under_scope(b: ty.handle, c: ty.handle) -> None:
+ A = tir.alloc_buffer((128, 128))
+ B = tir.match_buffer(b, (128, 128))
+ C = tir.match_buffer(c, (128, 128))
+ A_global = tir.alloc_buffer((128, 128))
+
+ with tir.block([8, 8], "scope") as [i, j]:
+ A_local = tir.alloc_buffer((128, 128), scope="local")
+ B_global = tir.alloc_buffer((128, 128))
+ for x, y in tir.grid(16, 16):
+ with tir.block([128, 128], "A_local") as [vi, vj]:
+ tir.bind(vi, i * 16 + x)
+ tir.bind(vj, j * 16 + y)
+ A_local[vi, vj] = 1.0
+ for x, y in tir.grid(16, 16):
+ with tir.block([128, 128], "A") as [vi, vj]:
+ tir.bind(vi, i * 16 + x)
+ tir.bind(vj, j * 16 + y)
+ A_global[vi, vj] = A_local[vi, vj]
+ for x, y in tir.grid(16, 16):
+ with tir.block([128, 128], "B_global") as [vi, vj]:
+ tir.bind(vi, i * 16 + x)
+ tir.bind(vj, j * 16 + y)
+ B_global[vi, vj] = A_global[vi, vj] + 1.0
+ for x, y in tir.grid(16, 16):
+ with tir.block([128, 128], "B_global") as [vi, vj]:
+ tir.bind(vi, i * 16 + x)
+ tir.bind(vj, j * 16 + y)
+ B[vi, vj] = B_global[vi, vj]
+ with tir.block([128, 128], "A_global") as [vi, vj]:
+ A[vi, vj] = A_global[vi, vj]
+ with tir.block([128, 128], "C") as [vi, vj]:
+ C[vi, vj] = A[vi, vj] * 2.0
+
+
[email protected]
+def cache_write_opaque_access(a: ty.handle, b: ty.handle, c: ty.handle, d:
ty.handle) -> None:
+ A = tir.match_buffer(a, (128, 128), dtype="float16")
+ B = tir.match_buffer(b, (128, 128), dtype="float16")
+ C = tir.match_buffer(c, (128, 128), dtype="float16")
+ D = tir.match_buffer(d, (128, 128), dtype="float16")
+ D_global = tir.alloc_buffer((128, 128), dtype="float16")
+ B_global = tir.alloc_buffer((128, 128), dtype="float16")
+ C_global = tir.alloc_buffer((128, 128), dtype="float16")
+
+ with tir.block([128, 128], "load_store") as [vi, vj]:
+ tir.reads(A[vi, vj])
+ tir.writes(D_global[vi, vj])
+        D_global.data[vi * 128 + vj] = tir.load("float16", A.data, vi * 128 + vj)
+ with tir.block([8, 8], "opaque") as [vi, vj]:
+ tir.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
+ tir.writes(B_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
+ tir.evaluate(
+ tir.tvm_load_matrix_sync(
+ B_global.data,
+ 16,
+ 16,
+ 16,
+ vi * 8 + vj,
+ tir.tvm_access_ptr(
+ tir.type_annotation(dtype="float16"),
+ A.data,
+ vi * 2048 + vj * 16,
+ 128,
+ 1,
+ dtype="handle",
+ ),
+ 128,
+ "row_major",
+ dtype="handle",
+ )
+ )
+ with tir.block([8, 8], "match_buffer") as [vi, vj]:
+ tir.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
+ tir.writes(C_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
+ A0 = tir.match_buffer(
+ A[
+ vi * 16 : vi * 16 + 16,
+ vj * 16 : vj * 16 + 16,
+ ],
+ (16, 16),
+ "float16",
+ strides=[128, 1],
+ offset_factor=1,
+ )
+ C0 = tir.match_buffer(
+ C_global[
+ vi * 16 : vi * 16 + 16,
+ vj * 16 : vj * 16 + 16,
+ ],
+ (16, 16),
+ "float16",
+ strides=[128, 1],
+ offset_factor=1,
+ )
+ tir.evaluate(
+ tir.tvm_load_matrix_sync(
+ C0.data,
+ 16,
+ 16,
+ 16,
+ vi * 8 + vj,
+ tir.tvm_access_ptr(
+ tir.type_annotation(dtype="float16"),
+ A0.data,
+ A0.elem_offset,
+ A0.strides[0],
+ 1,
+ dtype="handle",
+ ),
+ 128,
+ "row_major",
+ dtype="handle",
+ )
+ )
+
+ with tir.block([128, 128], "D") as [vi, vj]:
+ D[vi, vj] = D_global[vi, vj]
+ with tir.block([128, 128], "B") as [vi, vj]:
+ B[vi, vj] = B_global[vi, vj]
+ with tir.block([128, 128], "C") as [vi, vj]:
+ C[vi, vj] = C_global[vi, vj]
+
+
[email protected]
+def cache_write_multi_consumer() -> None:
+ A = tir.alloc_buffer((128))
+ B = tir.alloc_buffer((128))
+ C = tir.alloc_buffer((128))
+ A_global = tir.alloc_buffer((128))
+ for i in tir.grid(8):
+ for j in tir.grid(16):
+ with tir.block([128], "A_global") as [vi]:
+ tir.bind(vi, i * 16 + j)
+ A_global[vi] = 1.0
+ for j in tir.grid(16):
+ with tir.block([128], "A") as [vi]:
+ tir.bind(vi, i * 16 + j)
+ A[vi] = A_global[vi]
+ for j in tir.grid(16):
+ with tir.block([128], "B") as [vi]:
+ tir.bind(vi, i * 16 + j)
+ B[vi] = A[vi] + 1.0
+
+ for i in tir.grid(128):
+ with tir.block([128], "C") as [vi]:
+ C[vi] = A[vi]
+
+
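+# Expected result of applying cache_write twice to block "B": the write of B is
+# staged through B_local -> B_shared -> B.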
[email protected]
+def continuous_cache_write(a: ty.handle, c: ty.handle) -> None:
+ A = tir.match_buffer(a, (128, 128))
+ B = tir.alloc_buffer((128, 128))
+ C = tir.match_buffer(c, (128, 128))
+ B_shared = tir.alloc_buffer((128, 128), scope="shared")
+ B_local = tir.alloc_buffer((128, 128), scope="local")
+    with tir.block([128, 128], "B_local") as [vi, vj]:
+        B_local[vi, vj] = A[vi, vj] * 2.0
+    with tir.block([128, 128], "B_shared") as [vi, vj]:
+        B_shared[vi, vj] = B_local[vi, vj]
+    with tir.block([128, 128], "B") as [vi, vj]:
+        B[vi, vj] = B_shared[vi, vj]
+ with tir.block([128, 128], "C") as [vi, vj]:
+ C[vi, vj] = B[vi, vj] + 1.0
+
+
+########## Testcases for cache_read ##########
+
+
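+# cache_read(block, i, scope) stages the i-th read buffer of `block` into a new
+# buffer of the given storage scope; the returned handle refers to the inserted
+# cache-stage block (named "<buffer>_<scope>", e.g. "A_global").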
+def test_cache_read_elementwise():
+ sch = tir.Schedule(elementwise, debug_mask="all")
+ block_b = sch.get_block("B")
+ block_c = sch.get_block("C")
+ cached_a = sch.cache_read(block_b, 0, "global")
+ cached_b = sch.cache_read(block_c, 0, "local")
+ assert sch.get(cached_a) == sch.get(sch.get_block("A_global"))
+ assert sch.get(cached_b) == sch.get(sch.get_block("B_local"))
+ assert sch.get(block_b) == sch.get(sch.get_block("B"))
+ assert sch.get(block_c) == sch.get(sch.get_block("C"))
+ tvm.ir.assert_structural_equal(cache_read_elementwise, sch.mod["main"])
+ verify_trace_roundtrip(sch=sch, mod=elementwise)
+
+
+def test_cache_read_under_scope():
+ sch = tir.Schedule(access_under_scope, debug_mask="all")
+ block_b = sch.get_block("B")
+ block_c = sch.get_block("C")
+ sch.cache_read(block_b, 0, "local")
+ sch.cache_read(block_c, 0, "global")
+ tvm.ir.assert_structural_equal(cache_read_under_scope, sch.mod["main"])
+ verify_trace_roundtrip(sch=sch, mod=access_under_scope)
+
+
+def test_cache_read_opaque_access():
+ sch = tir.Schedule(opaque_access, debug_mask="all")
+ block = sch.get_block("load_store")
+ sch.cache_read(block, 0, "global")
+ tvm.ir.assert_structural_equal(cache_read_opaque_access, sch.mod["main"])
+ verify_trace_roundtrip(sch=sch, mod=opaque_access)
+
+
+def test_cache_read_location():
+ sch = tir.Schedule(func_multi_consumer, debug_mask="all")
+ block_b = sch.get_block("B")
+ sch.cache_read(block_b, 0, "global")
+ tvm.ir.assert_structural_equal(cache_read_multi_consumer, sch.mod["main"])
+ verify_trace_roundtrip(sch=sch, mod=func_multi_consumer)
+
+
+def test_continuous_cache_read():
+ sch = tir.Schedule(elementwise, debug_mask="all")
+ block_c = sch.get_block("C")
+ sch.cache_read(block_c, 0, "shared")
+ sch.cache_read(block_c, 0, "local")
+ tvm.ir.assert_structural_equal(continuous_cache_read, sch.mod["main"])
+ verify_trace_roundtrip(sch=sch, mod=elementwise)
+
+
+def test_cache_read_fail_multi_producer():
+ sch = tir.Schedule(func_multi_producer, debug_mask="all")
+ block_b = sch.get_block("B")
+ with pytest.raises(tvm.tir.ScheduleError):
+ sch.cache_read(block_b, 0, "global")
+
+
+def test_cache_read_fail_index_out_of_bound():
+ sch = tir.Schedule(elementwise, debug_mask="all")
+ block_b = sch.get_block("B")
+ with pytest.raises(tvm.tir.ScheduleError):
+ sch.cache_read(block_b, 1, "global")
+
+
+########## Testcases for cache_write ##########
+
+
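+# cache_write(block, i, scope) makes `block` compute into a new buffer of the
+# given storage scope and inserts a cache-stage block that copies the result
+# back to the i-th write buffer; the returned handle refers to that new block.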
+def test_cache_write_elementwise():
+ sch = tir.Schedule(elementwise, debug_mask="all")
+ block_b = sch.get_block("B")
+ block_c = sch.get_block("C")
+ cached_b = sch.cache_write(block_b, 0, "local")
+ cached_c = sch.cache_write(block_c, 0, "global")
+ assert sch.get(cached_b) == sch.get(sch.get_block("B_local"))
+ assert sch.get(cached_c) == sch.get(sch.get_block("C_global"))
+ assert sch.get(block_b) == sch.get(sch.get_block("B"))
+ assert sch.get(block_c) == sch.get(sch.get_block("C"))
+ tvm.ir.assert_structural_equal(cache_write_elementwise, sch.mod["main"])
+ verify_trace_roundtrip(sch=sch, mod=elementwise)
+
+
+def test_cache_write_under_scope():
+ sch = tir.Schedule(access_under_scope, debug_mask="all")
+ block_a = sch.get_block("A")
+ block_b = sch.get_block("B")
+ block_scope = sch.get_block("scope")
+ sch.cache_write(block_a, 0, "local")
+ sch.cache_write(block_b, 0, "global")
+ sch.cache_write(block_scope, 0, "global")
+ tvm.ir.assert_structural_equal(cache_write_under_scope, sch.mod["main"])
+ verify_trace_roundtrip(sch=sch, mod=access_under_scope)
+
+
+def test_cache_write_opaque_access():
+ sch = tir.Schedule(opaque_access, debug_mask="all")
+ block_store = sch.get_block("load_store")
+ block_opaque = sch.get_block("opaque")
+ block_match_buffer = sch.get_block("match_buffer")
+ sch.cache_write(block_store, 0, "global")
+ sch.cache_write(block_opaque, 0, "global")
+ sch.cache_write(block_match_buffer, 0, "global")
+ tvm.ir.assert_structural_equal(cache_write_opaque_access, sch.mod["main"])
+ verify_trace_roundtrip(sch=sch, mod=opaque_access)
+
+
+def test_cache_write_location():
+ sch = tir.Schedule(func_multi_consumer, debug_mask="all")
+ block_a = sch.get_block("A")
+ sch.cache_write(block_a, 0, "global")
+ tvm.ir.assert_structural_equal(cache_write_multi_consumer, sch.mod["main"])
+ verify_trace_roundtrip(sch=sch, mod=func_multi_consumer)
+
+
+def test_continuous_cache_write():
+ sch = tir.Schedule(elementwise, debug_mask="all")
+ block_b = sch.get_block("B")
+ sch.cache_write(block_b, 0, "shared")
+ sch.cache_write(block_b, 0, "local")
+ tvm.ir.assert_structural_equal(continuous_cache_write, sch.mod["main"])
+ verify_trace_roundtrip(sch=sch, mod=elementwise)
+
+
+def test_cache_write_fail_multi_producer():
+ sch = tir.Schedule(func_multi_producer, debug_mask="all")
+ block_a0 = sch.get_block("A0")
+ block_a1 = sch.get_block("A1")
+ with pytest.raises(tvm.tir.ScheduleError):
+ sch.cache_write(block_a0, 0, "global")
+ with pytest.raises(tvm.tir.ScheduleError):
+ sch.cache_write(block_a1, 0, "global")
+
+
+def test_cache_write_fail_index_out_of_bound():
+ sch = tir.Schedule(elementwise, debug_mask="all")
+ block_b = sch.get_block("B")
+ with pytest.raises(tvm.tir.ScheduleError):
+ sch.cache_write(block_b, 1, "global")
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main([__file__] + sys.argv[1:]))