This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 12a0f3edcf [TIR] Add schedule primitive ReIndex (#11515)
12a0f3edcf is described below
commit 12a0f3edcf8295288f4aa9ec3dbb6771c3a1a301
Author: Wuwei Lin <[email protected]>
AuthorDate: Thu Jun 2 14:34:23 2022 -0700
[TIR] Add schedule primitive ReIndex (#11515)
---
include/tvm/tir/schedule/schedule.h | 13 +
python/tvm/tir/schedule/schedule.py | 73 ++++
src/tir/schedule/concrete_schedule.cc | 10 +
src/tir/schedule/concrete_schedule.h | 2 +
src/tir/schedule/primitive.h | 15 +
src/tir/schedule/primitive/cache_read_write.cc | 468 +++++++++++++++++++++
src/tir/schedule/schedule.cc | 5 +
src/tir/schedule/traced_schedule.cc | 12 +
src/tir/schedule/traced_schedule.h | 2 +
src/tir/schedule/transform.cc | 26 ++
src/tir/schedule/transform.h | 21 +
tests/python/unittest/test_tir_schedule_reindex.py | 203 +++++++++
12 files changed, 850 insertions(+)
diff --git a/include/tvm/tir/schedule/schedule.h
b/include/tvm/tir/schedule/schedule.h
index 48014280a5..68900e107d 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -364,6 +364,19 @@ class ScheduleNode : public runtime::Object {
*/
virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
const String& storage_scope) = 0;
+ /*!
+ * \brief Create a block that read/write a buffer region into a read/write
cache with reindexing.
+ * The layout of the cache will be the same as by the iterators of the block
that reads/writes the
+ * buffer. It requires:
+ * 1) There is only one block who reads/writes the target buffer
+ * 2) There is only one buffer load/store of this buffer in the block
+ * \param block_rv The block operates on the target buffer.
+ * \param buffer_index The index of the buffer in block's read or write
region.
+ * \param buffer_index_type The type of the buffer index, kRead or kWrite.
+ * \return The reindex stage block.
+ */
+ virtual BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
+ BufferIndexType buffer_index_type) = 0;
/******** Schedule: Compute location ********/
/*!
* \brief Move a producer block under the specific loop, and regenerate the
diff --git a/python/tvm/tir/schedule/schedule.py
b/python/tvm/tir/schedule/schedule.py
index f86228848b..4179088aa5 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -1056,6 +1056,79 @@ class Schedule(Object):
self, block, write_buffer_index, storage_scope
)
+ @type_checked
+ def reindex(self, block: BlockRV, buffer_index: int, buffer_index_type:
str) -> BlockRV:
+ """Create a block that read/write a buffer region into a read/write
cache with reindexing.
+ The layout of the cache will be the same as by the iterators of the
block that reads/writes
+ the buffer. It requires:
+ 1) There is only one block who reads/writes the target buffer
+ 2) There is only one buffer load/store of this buffer in the block
+
+ Parameters
+ ----------
+ block: BlockRV
+ The block that accesses the target buffer
+ buffer_index: int
+ The index of the buffer in block's read or write region
+ buffer_index_type : str
+ Type of the buffer index, "read" or "write"
+
+ Returns
+ -------
+ reindex_block : BlockRV
+ The block of the reindex stage
+
+ Examples
+ --------
+
+ Before reindex, in TensorIR, the IR is:
+
+ .. code-block:: python
+
+ @T.prim_func
+ def before_reindex(
+ A: T.Buffer[(128, 128), "float32"],
+ B: T.Buffer[(128, 128), "float32"]
+ ) -> None:
+ for i, j in T.grid(128, 128):
+ with T.block("B"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ B[vi, vj] = A[vj, vi] * 2.0
+
+ Create the schedule and do reindex:
+
+ .. code-block:: python
+
+ sch = tir.Schedule(before_reindex)
+ block = sch.get_block("B")
+ sch.reindex(block, 0, "read")
+
+ After applying reindex, the IR becomes:
+
+ .. code-block:: python
+
+ @T.prim_func
+ def after_reindex(
+ A: T.Buffer[(128, 128), "float32"],
+ B: T.Buffer[(128, 128), "float32"]
+ ) -> None:
+ A_reindex = T.alloc_buffer((128, 128), "float32")
+ for i, j in T.grid(128, 128):
+ with T.block("A_reindex"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ A_reindex[vi, vj] = A[vj, vi]
+ for i, j in T.grid(128, 128):
+ with T.block("B"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ B[vi, vj] = A_reindex[vi, vj] * 2.0
+
+ """
+ assert buffer_index_type in ["read", "write"], "Invalid
buffer_index_type"
+ buffer_index_type_enum = 0 if buffer_index_type == "read" else 1
+ return _ffi_api.ScheduleReIndex( # type: ignore # pylint:
disable=no-member
+ self, block, buffer_index, buffer_index_type_enum
+ )
+
########## Schedule: Compute location ##########
@type_checked
diff --git a/src/tir/schedule/concrete_schedule.cc
b/src/tir/schedule/concrete_schedule.cc
index 2289899c32..590a0f0025 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -511,6 +511,16 @@ BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV&
block_rv, int write_buff
return CreateRV<BlockRV>(result);
}
+BlockRV ConcreteScheduleNode::ReIndex(const BlockRV& block_rv, int
buffer_index,
+ BufferIndexType buffer_index_type) {
+ StmtSRef result{nullptr};
+ TVM_TIR_SCHEDULE_BEGIN();
+ result = tir::ReIndex(state_, this->GetSRef(block_rv), buffer_index,
buffer_index_type);
+ TVM_TIR_SCHEDULE_END("reindex", this->error_render_level_);
+ this->state_->DebugVerify();
+ return CreateRV<BlockRV>(result);
+}
+
/******** Schedule: Compute location ********/
void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV&
loop_rv,
diff --git a/src/tir/schedule/concrete_schedule.h
b/src/tir/schedule/concrete_schedule.h
index 8e83aac2ce..70c0265611 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -109,6 +109,8 @@ class ConcreteScheduleNode : public ScheduleNode {
const String& storage_scope) override;
BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
const String& storage_scope) override;
+ BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
+ BufferIndexType buffer_index_type) override;
/******** Schedule: Compute location ********/
void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool
preserve_unit_loops) override;
void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index 50dedf71ff..f4dba69c6b 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -253,6 +253,21 @@ TVM_DLL StmtSRef CacheRead(ScheduleState self, const
StmtSRef& block_sref, int r
*/
TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref,
int write_buffer_index,
const String& storage_scope);
+/*!
+ * \brief Create a block that read/write a buffer region into a read/write
cache with reindexing.
+ * The layout of the cache will be the same as by the iterators of the block
that reads/writes the
+ * buffer. It requires:
+ * 1) There is only one block who reads/writes the target buffer
+ * 2) There is only one buffer load/store of this buffer in the block
+ * \param self The state of the schedule
+ * \param block_rv The block operates on the target buffer.
+ * \param buffer_index The index of the buffer in block's read or write region.
+ * \param buffer_index_type The type of the buffer index, kRead or kWrite.
+ * \return The reindex stage block.
+ */
+TVM_DLL StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int
buffer_index,
+ BufferIndexType buffer_index_type);
/******** Schedule: Compute location ********/
/*!
* \brief Move a producer block under the specific loop, and regenerate the
diff --git a/src/tir/schedule/primitive/cache_read_write.cc
b/src/tir/schedule/primitive/cache_read_write.cc
index 1bba2ae4fc..c96f88e1f6 100644
--- a/src/tir/schedule/primitive/cache_read_write.cc
+++ b/src/tir/schedule/primitive/cache_read_write.cc
@@ -160,6 +160,121 @@ Block MakeCacheStage(const BufferRegion& cache_region,
CacheStageInfo* info,
return block;
}
+/*!
+ * \brief Create the reindex block and generate the corresponding outer loops.
+ * \details The reindex block is a data copy block between the reindex buffer
(the intermediate
+ * buffer), and the target buffer.
+ If buffer_index_type == kWrite, copy from the reindex buffer to the target
buffer.
+ If buffer_index_type == kRead, copy from the target buffer to the reindex
buffer.
+ The reindex block has the same block iters and the surrounding loops as
the input block.
+ However, if a block iter is not used in the indices of the target buffer
being reindexed, the
+ domain of the block iter, and the corresponding outer loop, will become
constant value one, making
+ it a trivial iter.
+ * \param block The block to be reindexed
+ * \param info The cache info
+ * \param covered The set of block iter vars covered in the buffer access
indices
+ * \param original_indices The original buffer access indices
+ * \param buffer_index The index of the target buffer
+ * \param buffer_index_type The type of buffer index
+ * \return The reindex block.
+ */
+Block MakeReIndexStage(const Block& block, CacheStageInfo* info,
+ const std::unordered_set<Var, ObjectPtrHash,
ObjectPtrEqual>& covered,
+ const Array<PrimExpr>& original_indices, int
buffer_index,
+ BufferIndexType buffer_index_type) {
+ // iters of the reindex block
+ Array<IterVar> new_block_iters;
+ // the substitution map from the original block iter to the iters of the
reindex block
+ std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectEqual>
block_var_replace_map;
+ // block access region of reindexed buffer and target buffer
+ Region reindex_region, target_region;
+ // indices to access the reindex buffer and the target buffer
+ Array<PrimExpr> reindex_indices, target_indices;
+
+ // Step 1: Create block iters, access regions of the reindex block, and
accessing indices to the
+ // reindex buffer.
+ for (const IterVar& iter : block->iter_vars) {
+ Var var("v" + std::to_string(new_block_iters.size()));
+ bool used = covered.count(iter->var);
+ new_block_iters.push_back(IterVar(/*dom=*/used ? iter->dom :
Range::FromMinExtent(0, 1),
+ /*var=*/var,
+ /*IterVarType=*/kDataPar));
+ if (used) {
+ reindex_indices.push_back(var);
+ reindex_region.push_back(Range::FromMinExtent(var, 1));
+ }
+ block_var_replace_map[iter->var] = var;
+ }
+
+ // Step 2: Replace the original block iters with the new block iters
+ BufferRegion buffer_region = buffer_index_type == BufferIndexType::kWrite
+ ? block->writes[buffer_index]
+ : block->reads[buffer_index];
+ target_region = Substitute(buffer_region->region, block_var_replace_map);
+ for (const PrimExpr& index : original_indices) {
+ target_indices.push_back(Substitute(index, block_var_replace_map));
+ }
+
+ // Step 3: Create the reindex block
+
+ // The src and the dst region and indices of the data copy
+ Region src_region{nullptr};
+ Region dst_region{nullptr};
+ Array<PrimExpr> src_indices{nullptr};
+ Array<PrimExpr> dst_indices{nullptr};
+
+ if (buffer_index_type == BufferIndexType::kWrite) {
+ src_region = reindex_region;
+ dst_region = target_region;
+ src_indices = reindex_indices;
+ dst_indices = target_indices;
+ } else {
+ src_region = target_region;
+ dst_region = reindex_region;
+ src_indices = target_indices;
+ dst_indices = reindex_indices;
+ }
+
+ // Create the body block
+ Block new_block(
+ /*iter_vars=*/new_block_iters,
+ /*reads=*/
+ {BufferRegion(info->read_buffer, src_region)},
+ /*writes=*/
+ {BufferRegion(info->write_buffer, dst_region)},
+ /*name_hint=*/buffer_region->buffer->name + "_reindex",
+ /*body=*/
+ BufferStore(info->write_buffer, BufferLoad(info->read_buffer,
src_indices), dst_indices));
+
+ // Step 4: Create surrounding loops
+
+ // Create loop vars and bindings for block iters
+ std::vector<Var> loop_vars; // loop variables
+ std::vector<PrimExpr> iter_values; // bindings in block realize
+ for (int i = 0; i < static_cast<int>(block->iter_vars.size()); ++i) {
+ Var loop_var("ax" + std::to_string(loop_vars.size()));
+ loop_vars.push_back(loop_var);
+ iter_values.push_back(loop_var);
+ }
+
+ // Create the block realize node
+ Stmt body = BlockRealize(/*values=*/iter_values,
+ /*predicate=*/const_true(),
+ /*block=*/new_block);
+
+ // Create the chain of loops
+ for (int i = static_cast<int>(new_block_iters.size()) - 1; i >= 0; --i) {
+ body = For(/*loop_var=*/loop_vars[i],
+ /*min=*/new_block_iters[i]->dom->min,
+ /*extent=*/new_block_iters[i]->dom->extent,
+ /*kind=*/ForKind::kSerial,
+ /*body=*/std::move(body));
+ }
+ // Update cache info, which will be used in the later rewriting.
+ info->cache_stage = std::move(body);
+ return new_block;
+}
+
/*!
* \brief Recalculate the `affine_binding` flag of a specifc block
* \param block_sref The sref to the specific block
@@ -599,6 +714,252 @@ class CacheWriteRewriter : public StmtExprMutator {
bool under_writer_block_{false};
};
+/*!
+ * \brief Create a new buffer by changing the shape with block iters to be used
as the reindex buffer
+ * \param buffer The given buffer.
+ * \param block_iters The block iters.
+ * \param covered Set of block iter vars covered by the buffer access indices
+ * \return The new buffer with target shape.
+ */
+Buffer CreateReindexBuffer(const Buffer& buffer, const Array<IterVar>&
block_iters,
+ const std::unordered_set<Var, ObjectPtrHash,
ObjectPtrEqual>& covered) {
+ ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*buffer.get());
+ ObjectPtr<VarNode> new_var = make_object<VarNode>(*buffer->data.get());
+ std::vector<PrimExpr> new_shape;
+ std::vector<PrimExpr> new_strides;
+ for (const auto& iter : block_iters) {
+ if (covered.count(iter->var)) {
+ new_shape.push_back(iter->dom->min + iter->dom->extent);
+ }
+ }
+ new_strides.clear();
+ new_buffer->shape = new_shape;
+ new_buffer->strides = new_strides;
+ new_buffer->data = buffer->data.copy_with_suffix("_reindex");
+ new_buffer->name = buffer->name + "_reindex";
+ return Buffer(new_buffer);
+}
+
+/*! \brief The schedule error that the target is not a leaf block. */
+class NotLeafBlockError : public ScheduleError {
+ public:
+ NotLeafBlockError(IRModule mod, Block block) : mod_(std::move(mod)),
block_(std::move(block)) {}
+ String FastErrorString() const final {
+ return "ScheduleError: The target block is not a leaf block.";
+ }
+
+ String DetailRenderTemplate() const final { return "The target block {0} is
not a leaf block."; }
+
+ IRModule mod() const final { return mod_; }
+ Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+ IRModule mod_;
+ Block block_;
+};
+
+/*! \brief The schedule error that the buffer access is invalid for reindex. */
+class InvalidBufferAccessError : public ScheduleError {
+ public:
+ enum class ErrorKind {
+ kNoAccess, // buffer access not found
+ kNonUniqueAccess, // multiple buffer accesses with different indices
+ kOpaqueAccess, // opaque access to the buffer
+ };
+
+ InvalidBufferAccessError(IRModule mod, Buffer buffer, Block block, ErrorKind
kind)
+ : mod_(std::move(mod)), buffer_(std::move(buffer)),
block_(std::move(block)), kind_(kind) {}
+ String FastErrorString() const final {
+ return "ScheduleError: The target buffer should be accessed via BufferLoad
or BufferStore. The "
+ "indices should be the same if there are multiple accesses to the
target buffer.";
+ }
+
+ String DetailRenderTemplate() const final {
+ std::ostringstream os;
+ os << "The target buffer " << buffer_->name
+ << " should be accessed in the leaf block {0} via BufferLoad or
BufferStore. The indices "
+ "should be the same if there are multiple accesses to the target
buffer. ";
+ if (kind_ == ErrorKind::kNoAccess) {
+ os << "No buffer accesses found.";
+ } else if (kind_ == ErrorKind::kNonUniqueAccess) {
+ os << "Multiple buffer accesses have non-unique indices.";
+ } else if (kind_ == ErrorKind::kOpaqueAccess) {
+ os << "Opaque buffer accesses found.";
+ }
+ return os.str();
+ }
+ IRModule mod() const final { return mod_; }
+ Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+
+ private:
+ IRModule mod_;
+ Buffer buffer_;
+ Block block_;
+ ErrorKind kind_;
+};
+
+/*! \brief Collect the related Load/Store to reindex */
+class ReIndexCollector : public StmtExprVisitor {
+ public:
+ static Array<PrimExpr> Collect(const IRModule& mod, const Buffer& buffer,
const Block& block) {
+ ReIndexCollector collector(mod, buffer, block);
+ collector(block->body);
+ if (!collector.buffer_access_indices_.defined()) {
+ throw InvalidBufferAccessError(mod, buffer, block,
+
InvalidBufferAccessError::ErrorKind::kNoAccess);
+ }
+ return collector.buffer_access_indices_.value();
+ }
+
+ private:
+ explicit ReIndexCollector(const IRModule& mod, const Buffer& buffer, const
Block& block)
+ : mod_(mod), buffer_(buffer), block_(block) {}
+
+ void VisitExpr_(const BufferLoadNode* load) final {
+ StmtExprVisitor::VisitExpr_(load);
+ if (load->buffer.same_as(buffer_)) {
+ CheckAndUpdateBufferAccessIndices(load->indices);
+ }
+ }
+
+ void VisitStmt_(const BlockNode* block) final {
+ // no sub-blocks under this block
+ throw NotLeafBlockError(mod_, block_);
+ }
+
+ void VisitStmt_(const BufferStoreNode* store) final {
+ StmtExprVisitor::VisitStmt_(store);
+ if (store->buffer.same_as(buffer_)) {
+ CheckAndUpdateBufferAccessIndices(store->indices);
+ }
+ }
+
+ void CheckAndUpdateBufferAccessIndices(const Array<PrimExpr> indices) {
+ if (!buffer_access_indices_.defined()) {
+ buffer_access_indices_ = indices;
+ return;
+ } else if (!std::equal(buffer_access_indices_.value().begin(),
+ buffer_access_indices_.value().end(),
indices.begin(), indices.end(),
+ ExprDeepEqual())) {
+ throw InvalidBufferAccessError(mod_, buffer_, block_,
+
InvalidBufferAccessError::ErrorKind::kNonUniqueAccess);
+ }
+ }
+
+ void VisitExpr_(const VarNode* var) final {
+ if (var == buffer_->data.get()) {
+ throw InvalidBufferAccessError(mod_, buffer_, block_,
+
InvalidBufferAccessError::ErrorKind::kOpaqueAccess);
+ }
+ }
+ /*! \brief The IR module */
+ IRModule mod_;
+ /*! \brief The buffer to rewrite */
+ Buffer buffer_;
+ /*! \brief The block to visit */
+ Block block_;
+ /*! \brief The indices of buffer access to rewrite */
+ Optional<Array<PrimExpr>> buffer_access_indices_;
+};
+
+/*! \brief Mutator of ReIndex */
+class ReIndexRewriter : public StmtExprMutator {
+ public:
+ static Stmt Rewrite(const StmtSRef& scope_sref, const StmtSRef& block_sref,
CacheStageInfo* info,
+ const std::unordered_set<Var, ObjectPtrHash,
ObjectPtrEqual>& covered) {
+ ReIndexRewriter rewriter(block_sref, info, covered);
+ return rewriter(GetRef<Stmt>(scope_sref->stmt));
+ }
+
+ private:
+ explicit ReIndexRewriter(const StmtSRef& block_sref, CacheStageInfo* info,
+ const std::unordered_set<Var, ObjectPtrHash,
ObjectPtrEqual>& covered)
+ : block_sref_(block_sref), info_(info), covered_(covered) {
+ new_buffer_ = info->alloc;
+ old_buffer_ = info->read_buffer.same_as(new_buffer_) ? info->write_buffer
: info->read_buffer;
+ }
+
+ Stmt VisitStmt_(const BlockNode* block) final {
+ Block old_stmt = GetRef<Block>(block);
+ if (is_scope_) {
+ is_scope_ = false;
+ Block stmt = Downcast<Block>(StmtExprMutator::VisitStmt_(block));
+ // Insert cache stage into the loop
+ ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
+ n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage);
+ n->alloc_buffers.push_back(info_->alloc);
+ stmt = Block(n);
+ info_->block_reuse.Set(old_stmt, stmt);
+ return stmt;
+ }
+
+ // Visiting the block being reindexed
+ if (block == block_sref_->stmt) {
+ // Collect the updated indices and regions
+ for (const IterVar& iter : block->iter_vars) {
+ if (covered_.count(iter->var)) {
+ indices_.push_back(iter->var);
+ region_.push_back(Range::FromMinExtent(iter->var, 1));
+ }
+ }
+ Block stmt = Downcast<Block>(StmtExprMutator::VisitStmt_(block));
+ // Update block reads/writes to use the intermediate reindex buffer
+ auto writes =
+ ReplaceBufferRegion(block->writes, old_buffer_,
BufferRegion{new_buffer_, region_});
+ auto reads =
+ ReplaceBufferRegion(block->reads, old_buffer_,
BufferRegion{new_buffer_, region_});
+ auto match_buffers = ReplaceBufferRegion(block->match_buffers,
old_buffer_,
+ BufferRegion{new_buffer_,
region_});
+ if (!writes.same_as(block->writes) || !reads.same_as(block->reads) ||
+ !match_buffers.same_as(block->match_buffers)) {
+ ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
+ n->writes = std::move(writes);
+ n->reads = std::move(reads);
+ n->match_buffers = std::move(match_buffers);
+ stmt = Block(n);
+ }
+ info_->block_reuse.Set(old_stmt, stmt);
+ return stmt;
+ }
+ return old_stmt;
+ }
+
+ template <typename Node>
+ Node VisitBufferAccess(Node node) {
+ if (node->buffer.same_as(old_buffer_)) {
+ auto* n = node.CopyOnWrite();
+ n->buffer = new_buffer_;
+ n->indices = indices_;
+ }
+ return node;
+ }
+ Stmt VisitStmt_(const BufferStoreNode* op) final {
+ BufferStore buffer_store =
Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
+ return VisitBufferAccess(std::move(buffer_store));
+ }
+
+ PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+ BufferLoad buffer_load =
Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
+ return VisitBufferAccess(std::move(buffer_load));
+ }
+
+ private:
+ /*! \brief The parent scope of the insertion. */
+ const StmtSRef& block_sref_;
+ /*! \brief The info for inserting reindex stage. */
+ CacheStageInfo* info_;
+ /*! \brief Whether old block var is covered in the indices */
+ const std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>& covered_;
+ /*! \brief Whether the current block is scope block */
+ bool is_scope_{true};
+ /*! \brief The buffer to be replaced */
+ Buffer old_buffer_;
+ /*! \brief The reindex buffer */
+ Buffer new_buffer_;
+ /*! \brief The new indices */
+ Array<PrimExpr> indices_;
+ /*! \brief The new region */
+ Region region_;
+};
+
/******** Implementation ********/
StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int
read_buffer_index,
@@ -729,6 +1090,80 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef&
block_sref, int write_bu
return result_block_sref;
}
+StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int
buffer_index,
+ BufferIndexType buffer_index_type) {
+ const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref);
+ Block block = GetRef<Block>(block_ptr);
+ Buffer buffer =
+ GetNthAccessBuffer(self, block, buffer_index, buffer_index_type ==
BufferIndexType::kWrite);
+ StmtSRef scope_sref = GetScopeRoot(self, block_sref,
/*require_stage_pipeline=*/true);
+ arith::Analyzer analyzer;
+
+ // Step 1. Collect the original indices and check there's only single
pattern of related
+ // Load/Store and the buffer is not accessed opaquely
+ Array<PrimExpr> original_indices = ReIndexCollector::Collect(self->mod,
buffer, block);
+ // Simplify the indices if possible
+ for (const IterVar& iter : block->iter_vars) {
+ analyzer.Bind(iter->var, iter->dom);
+ }
+ original_indices.MutateByApply(
+ [&analyzer](const PrimExpr& expr) { return analyzer.Simplify(expr); });
+
+ // Collect block iters appearing in the original_indices
+ std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> covered;
+ for (const PrimExpr& index : original_indices) {
+ PreOrderVisit(index, [&](const ObjectRef& obj) -> bool {
+ if (const VarNode* var = obj.as<VarNode>()) {
+ covered.insert(GetRef<Var>(var));
+ }
+ return true;
+ });
+ }
+
+ // Step 2. Creating CacheStageInfo
+ CacheStageInfo info;
+ // Create the corresponding buffer to be read(write), i.e. the result of
reindex read(write)
+ if (buffer_index_type == BufferIndexType::kWrite) {
+ info.read_buffer = CreateReindexBuffer(buffer, block->iter_vars, covered);
+ info.write_buffer = buffer;
+ info.alloc = info.read_buffer;
+ } else {
+ info.read_buffer = buffer;
+ info.write_buffer = CreateReindexBuffer(buffer, block->iter_vars, covered);
+ info.alloc = info.write_buffer;
+ }
+
+ // Step 3. Check the block belongs to a chain loop nesting under the scope,
+ // and get the insert location
+ const StmtSRefNode* loop;
+ for (loop = block_sref->parent; loop->parent != scope_sref.get();) {
+ const ForNode* outer = loop->parent->StmtAs<ForNode>();
+ const ForNode* inner = loop->StmtAs<ForNode>();
+ ICHECK(outer != nullptr && inner != nullptr);
+ ICHECK(outer->body.get() == inner);
+ loop = loop->parent;
+ }
+
+ info.loc_pos = loop->seq_index == -1 ? 0 : loop->seq_index;
+ if (buffer_index_type == BufferIndexType::kWrite) {
+ info.loc_pos++;
+ }
+
+ // Step 4. Making new reindex stage block and rewrite
+ Block reindex_stage =
+ MakeReIndexStage(block, &info, covered, original_indices, buffer_index,
buffer_index_type);
+ Stmt new_scope = ReIndexRewriter::Rewrite(scope_sref, block_sref, &info,
covered);
+
+ // Step 5. Replacing and updating flags
+ self->Replace(scope_sref, new_scope, info.block_reuse);
+ StmtSRef result_block_sref = self->stmt2ref.at(reindex_stage.get());
+ BlockInfo& block_info = self->block_info[result_block_sref];
+ block_info.affine_binding = CalculateAffineFlag(self, result_block_sref);
+ block_info.region_cover = true;
+ block_info.scope->stage_pipeline = true;
+ return result_block_sref;
+}
+
/******** Instruction Registration ********/
struct CacheReadTraits : public UnpackedInstTraits<CacheReadTraits> {
@@ -787,7 +1222,40 @@ struct CacheWriteTraits : public
UnpackedInstTraits<CacheWriteTraits> {
friend struct ::tvm::tir::UnpackedInstTraits;
};
+struct ReIndexTraits : public UnpackedInstTraits<ReIndexTraits> {
+ static constexpr const char* kName = "ReIndex";
+ static constexpr bool kIsPure = false;
+
+ private:
+ static constexpr size_t kNumInputs = 1;
+ static constexpr size_t kNumAttrs = 2;
+ static constexpr size_t kNumDecisions = 0;
+
+ static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer
buffer_index,
+ Integer buffer_index_type) {
+ return sch->ReIndex(block, buffer_index,
+
static_cast<BufferIndexType>(buffer_index_type->value));
+ }
+
+ static String UnpackedAsPython(Array<String> outputs, String block, Integer
buffer_index,
+ Integer buffer_index_type) {
+ PythonAPICall py("reindex");
+ py.Input("block", block);
+ py.Input("buffer_index", buffer_index);
+ py.Input("buffer_index_type", '"' +
+ std::string(BufferIndexType2Str(
+
static_cast<BufferIndexType>(buffer_index_type->value))) +
+ '"');
+ py.SingleOutput(outputs);
+ return py.Str();
+ }
+
+ template <typename>
+ friend struct ::tvm::tir::UnpackedInstTraits;
+};
+
TVM_REGISTER_INST_KIND_TRAITS(CacheReadTraits);
TVM_REGISTER_INST_KIND_TRAITS(CacheWriteTraits);
+TVM_REGISTER_INST_KIND_TRAITS(ReIndexTraits);
} // namespace tir
} // namespace tvm
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index fb884ce77f..3880d0b19e 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -165,6 +165,11 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheRead")
.set_body_method<Schedule>(&ScheduleNode::CacheRead);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheWrite")
.set_body_method<Schedule>(&ScheduleNode::CacheWrite);
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReIndex")
+ .set_body_typed([](Schedule self, const BlockRV& block_rv, int
buffer_index,
+ int buffer_index_type) {
+ return self->ReIndex(block_rv, buffer_index,
static_cast<BufferIndexType>(buffer_index_type));
+ });
/******** (FFI) Compute location ********/
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeAt")
.set_body_method<Schedule>(&ScheduleNode::ComputeAt);
diff --git a/src/tir/schedule/traced_schedule.cc
b/src/tir/schedule/traced_schedule.cc
index 8156480a45..d2f627edfd 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -265,6 +265,18 @@ BlockRV TracedScheduleNode::CacheWrite(const BlockRV&
block_rv, int write_buffer
return result;
}
+BlockRV TracedScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index,
+ BufferIndexType buffer_index_type) {
+ BlockRV result = ConcreteScheduleNode::ReIndex(block_rv, buffer_index,
buffer_index_type);
+
+ static const InstructionKind& kind = InstructionKind::Get("ReIndex");
+ trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
+ /*inputs=*/{block_rv},
+ /*attrs=*/{Integer(buffer_index),
Integer(buffer_index_type)},
+ /*outputs=*/{result}));
+ return result;
+}
+
/******** Schedule: Compute location ********/
void TracedScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV&
loop_rv,
diff --git a/src/tir/schedule/traced_schedule.h
b/src/tir/schedule/traced_schedule.h
index d1860be951..ba4a4b99cb 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -73,6 +73,8 @@ class TracedScheduleNode : public ConcreteScheduleNode {
const String& storage_scope) final;
BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
const String& storage_scope) final;
+ BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
+ BufferIndexType buffer_index_type) final;
/******** Schedule: Compute location ********/
void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool
preserve_unit_loops) final;
void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc
index 79802ecd65..67d0f55f20 100644
--- a/src/tir/schedule/transform.cc
+++ b/src/tir/schedule/transform.cc
@@ -70,6 +70,32 @@ Array<MatchBufferRegion>
ReplaceBuffer(Array<MatchBufferRegion> match_buffers, c
return match_buffers;
}
+Array<BufferRegion> ReplaceBufferRegion(Array<BufferRegion> regions, const
Buffer& source_buffer,
+ const BufferRegion& target) {
+ regions.MutateByApply([&source_buffer, &target](const BufferRegion& region)
-> BufferRegion {
+ if (region->buffer.same_as(source_buffer)) {
+ return target;
+ }
+ return region;
+ });
+ return regions;
+}
+
+Array<MatchBufferRegion> ReplaceBufferRegion(Array<MatchBufferRegion>
match_buffers,
+ const Buffer& source_buffer,
+ const BufferRegion& target) {
+ match_buffers.MutateByApply([&source_buffer, &target](
+ const MatchBufferRegion& match_buffer) ->
MatchBufferRegion {
+ if (match_buffer->source->buffer.same_as(source_buffer)) {
+ ObjectPtr<MatchBufferRegionNode> n =
make_object<MatchBufferRegionNode>(*match_buffer.get());
+ n->source = target;
+ return MatchBufferRegion(n);
+ }
+ return match_buffer;
+ });
+ return match_buffers;
+}
+
/******** ReplaceBufferMutator ********/
ReplaceBufferMutator::ReplaceBufferMutator(const Buffer& old_buffer, Buffer
new_buffer,
Map<Block, Block>* block_sref_reuse)
diff --git a/src/tir/schedule/transform.h b/src/tir/schedule/transform.h
index 192d44d9e9..908a823c2d 100644
--- a/src/tir/schedule/transform.h
+++ b/src/tir/schedule/transform.h
@@ -73,6 +73,27 @@ Array<BufferRegion> ReplaceBuffer(Array<BufferRegion>
regions, const Buffer& sou
Array<MatchBufferRegion> ReplaceBuffer(Array<MatchBufferRegion> match_buffers,
const Buffer& source,
const Buffer& target);
+/*!
+ * \brief Replaces the buffer region within the specific sequence of regions
+ * \param regions The regions to be replaced
+ * \param source_buffer The buffer whose region is to be replaced
+ * \param target The buffer region to be replaced to
+ * \return The new sequence of regions after replacement
+ */
+Array<BufferRegion> ReplaceBufferRegion(Array<BufferRegion> regions, const
Buffer& source_buffer,
+ const BufferRegion& target);
+
+/*!
+ * \brief Replaces the buffer region within the specific sequence of
match_buffers
+ * \param regions The match_buffers to be replaced
+ * \param source_buffer The buffer whose region is to be replaced
+ * \param target The buffer region to be replaced to
+ * \return The new sequence of match_buffers after replacement
+ */
+Array<MatchBufferRegion> ReplaceBufferRegion(Array<MatchBufferRegion>
match_buffers,
+ const Buffer& source_buffer,
+ const BufferRegion& target);
+
/*!
* \brief A helper mutator which recursively replaces the old buffer with the
new buffer and
* collects the block sref reuse information for the following replacement.
diff --git a/tests/python/unittest/test_tir_schedule_reindex.py
b/tests/python/unittest/test_tir_schedule_reindex.py
new file mode 100644
index 0000000000..9b2e37a198
--- /dev/null
+++ b/tests/python/unittest/test_tir_schedule_reindex.py
@@ -0,0 +1,203 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import pytest
+import tvm
+import tvm.testing
+from tvm import tir
+from tvm.script import tir as T
+from tvm.tir.schedule.schedule import ScheduleError
+from tvm.tir.schedule.testing import verify_trace_roundtrip
+
+
[email protected]_func
+def transpose_elementwise(
+ A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]
+) -> None:
+ for i, j in T.grid(128, 128):
+ with T.block("B"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ B[vi, vj] = A[vj, vi] * 2.0
+
+
[email protected]_func
+def transpose_elementwise_reindex_read(
+ A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]
+) -> None:
+ A_reindex = T.alloc_buffer((128, 128), "float32")
+ for i, j in T.grid(128, 128):
+ with T.block("A_reindex"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ A_reindex[vi, vj] = A[vj, vi]
+ for i, j in T.grid(128, 128):
+ with T.block("B"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ B[vi, vj] = A_reindex[vi, vj] * 2.0
+
+
[email protected]_func
+def conv2d_nhwc(
+ Input: T.Buffer[(1, 224, 224, 3), "float32"],
+ Weight: T.Buffer[(7, 7, 3, 64), "float32"],
+ Conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float32"],
+) -> None:
+ PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32")
+ for i0, i1, i2, i3 in T.grid(1, 230, 230, 3):
+ with T.block("PadInput"):
+ i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+ PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
+ ((((i1_1 >= 3) and (i1_1 < 227)) and (i2_1 >= 3)) and (i2_1 <
227)),
+ Input[i0_1, (i1_1 - 3), (i2_1 - 3), i3_1],
+ T.float32(0),
+ dtype="float32",
+ )
+ for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 112, 112, 64, 7, 7, 3):
+ with T.block("conv2d_nhwc"):
+ n, h, w, co, rh, rw, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3,
i4, i5, i6])
+ with T.init():
+ Conv2d_nhwc[n, h, w, co] = T.float32(0)
+ Conv2d_nhwc[n, h, w, co] = Conv2d_nhwc[n, h, w, co] + (
+ PadInput[n, ((h * 2) + rh), ((w * 2) + rw), ((T.floordiv(co,
64) * 3) + rc)]
+ * Weight[rh, rw, rc, co]
+ )
+
+
[email protected]_func
+def conv2d_nhwc_reindex_weight(
+ var_inputs: T.handle, var_weight: T.handle, var_conv2d_nhwc: T.handle
+) -> None:
+ inputs = T.match_buffer(var_inputs, [1, 224, 224, 3], dtype="float32")
+ weight = T.match_buffer(var_weight, [7, 7, 3, 64], dtype="float32")
+ conv2d_nhwc = T.match_buffer(var_conv2d_nhwc, [1, 112, 112, 64],
dtype="float32")
+ PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32")
+ weight_reindex = T.alloc_buffer([64, 7, 7, 3], dtype="float32")
+ for i0, i1, i2, i3 in T.grid(1, 230, 230, 3):
+ with T.block("PadInput"):
+ i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+ T.reads(inputs[i0_1, i1_1 - 3, i2_1 - 3, i3_1])
+ T.writes(PadInput[i0_1, i1_1, i2_1, i3_1])
+ PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
+ i1_1 >= 3 and i1_1 < 227 and i2_1 >= 3 and i2_1 < 227,
+ inputs[i0_1, i1_1 - 3, i2_1 - 3, i3_1],
+ T.float32(0),
+ dtype="float32",
+ )
+ for ax0, ax1, ax2, ax3, ax4, ax5, ax6 in T.grid(1, 1, 1, 64, 7, 7, 3):
+ with T.block("weight_reindex"):
+ v0, v1, v2, v3, v4, v5, v6 = T.axis.remap(
+ "SSSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5, ax6]
+ )
+ T.reads(weight[v4, v5, v6, v3])
+ T.writes(weight_reindex[v3, v4, v5, v6])
+ weight_reindex[v3, v4, v5, v6] = weight[v4, v5, v6, v3]
+ for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 112, 112, 64, 7, 7, 3):
+ with T.block("conv2d_nhwc"):
+ n, h, w, co, rh, rw, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3,
i4, i5, i6])
+ T.reads(
+ PadInput[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc],
+ weight_reindex[co, rh, rw, rc],
+ )
+ T.writes(conv2d_nhwc[n, h, w, co])
+ with T.init():
+ conv2d_nhwc[n, h, w, co] = T.float32(0)
+ conv2d_nhwc[n, h, w, co] = (
+ conv2d_nhwc[n, h, w, co]
+ + PadInput[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc]
+ * weight_reindex[co, rh, rw, rc]
+ )
+
+
[email protected]_func
+def matmul(
+ A: T.Buffer[(512, 512), "float32"],
+ B: T.Buffer[(512, 512), "float32"],
+ C: T.Buffer[(512, 512), "float32"],
+) -> None:
+ for i0, i1, i2 in T.grid(512, 512, 512):
+ with T.block("matmul"):
+ i, j, k = T.axis.remap("SSR", [i0, i1, i2])
+ T.reads(C[i, j], A[i, k], B[k, j])
+ T.writes(C[i, j])
+ with T.init():
+ C[i, j] = T.float32(0)
+ C[i, j] = C[i, j] + A[i, k] * B[k, j]
+
+
[email protected]_func
+def matmul_reindex_write(
+ A: T.Buffer[(512, 512), "float32"],
+ B: T.Buffer[(512, 512), "float32"],
+ C: T.Buffer[(512, 512), "float32"],
+) -> None:
+ C_reindex = T.alloc_buffer([512, 512], dtype="float32")
+ for i0, i1, i2 in T.grid(512, 512, 512):
+ with T.block("matmul"):
+ i, j, k = T.axis.remap("SSR", [i0, i1, i2])
+ T.reads(C_reindex[i, j], A[i, k], B[k, j])
+ T.writes(C_reindex[i, j])
+ with T.init():
+ C_reindex[i, j] = T.float32(0)
+ C_reindex[i, j] = C_reindex[i, j] + A[i, k] * B[k, j]
+ for i0, i1, i2 in T.grid(512, 512, 1):
+ with T.block("C_reindex"):
+ v0, v1, v2 = T.axis.remap("SSS", [i0, i1, i2])
+ T.reads(C_reindex[v0, v1])
+ T.writes(C[v0, v1])
+ C[v0, v1] = C_reindex[v0, v1]
+
+
[email protected]_func
+def multiple_read(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128),
"float32"]) -> None:
+ for i, j in T.grid(128, 128):
+ with T.block("B"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ B[vi, vj] = A[vj, vi] + A[vi, vj]
+
+
+def test_reindex_read_basic():
+ sch = tir.Schedule(transpose_elementwise)
+ block = sch.get_block("B")
+ sch.reindex(block, 0, "read")
+ tvm.ir.assert_structural_equal(transpose_elementwise_reindex_read,
sch.mod["main"])
+ verify_trace_roundtrip(sch=sch, mod=transpose_elementwise)
+
+
+def test_conv2d_reindex_read():
+ sch = tir.Schedule(conv2d_nhwc)
+ block = sch.get_block("conv2d_nhwc")
+ sch.reindex(block, 1, "read")
+ tvm.ir.assert_structural_equal(conv2d_nhwc_reindex_weight, sch.mod["main"])
+ verify_trace_roundtrip(sch=sch, mod=conv2d_nhwc)
+
+
+def test_matmul_reindex_write():
+ sch = tir.Schedule(matmul)
+ block = sch.get_block("matmul")
+ sch.reindex(block, 0, "write")
+ tvm.ir.assert_structural_equal(matmul_reindex_write, sch.mod["main"])
+ verify_trace_roundtrip(sch=sch, mod=matmul)
+
+
+def test_reindex_fail_multiple_read():
+ sch = tir.Schedule(multiple_read)
+ block = sch.get_block("B")
+ with pytest.raises(ScheduleError):
+ sch.reindex(block, 0, "read")
+
+
+if __name__ == "__main__":
+ tvm.testing.main()