This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 12a0f3edcf [TIR] Add schedule primitive ReIndex (#11515)
12a0f3edcf is described below

commit 12a0f3edcf8295288f4aa9ec3dbb6771c3a1a301
Author: Wuwei Lin <[email protected]>
AuthorDate: Thu Jun 2 14:34:23 2022 -0700

    [TIR] Add schedule primitive ReIndex (#11515)
---
 include/tvm/tir/schedule/schedule.h                |  13 +
 python/tvm/tir/schedule/schedule.py                |  73 ++++
 src/tir/schedule/concrete_schedule.cc              |  10 +
 src/tir/schedule/concrete_schedule.h               |   2 +
 src/tir/schedule/primitive.h                       |  15 +
 src/tir/schedule/primitive/cache_read_write.cc     | 468 +++++++++++++++++++++
 src/tir/schedule/schedule.cc                       |   5 +
 src/tir/schedule/traced_schedule.cc                |  12 +
 src/tir/schedule/traced_schedule.h                 |   2 +
 src/tir/schedule/transform.cc                      |  26 ++
 src/tir/schedule/transform.h                       |  21 +
 tests/python/unittest/test_tir_schedule_reindex.py | 203 +++++++++
 12 files changed, 850 insertions(+)

diff --git a/include/tvm/tir/schedule/schedule.h 
b/include/tvm/tir/schedule/schedule.h
index 48014280a5..68900e107d 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -364,6 +364,19 @@ class ScheduleNode : public runtime::Object {
    */
   virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
                              const String& storage_scope) = 0;
+  /*!
+   * \brief Create a block that reads/writes a buffer region into a read/write 
cache with
+   * reindexing. The layout of the cache will be determined by the iterators of 
the block that
+   * reads/writes the buffer. It requires:
+   * 1) There is only one block that reads/writes the target buffer
+   * 2) There is only one buffer load/store of this buffer in the block
+   * \param block_rv The block operates on the target buffer.
+   * \param buffer_index The index of the buffer in block's read or write 
region.
+   * \param buffer_index_type The type of the buffer index, kRead or kWrite.
+   * \return The reindex stage block.
+   */
+  virtual BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
+                          BufferIndexType buffer_index_type) = 0;
   /******** Schedule: Compute location ********/
   /*!
    * \brief Move a producer block under the specific loop, and regenerate the
diff --git a/python/tvm/tir/schedule/schedule.py 
b/python/tvm/tir/schedule/schedule.py
index f86228848b..4179088aa5 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -1056,6 +1056,79 @@ class Schedule(Object):
             self, block, write_buffer_index, storage_scope
         )
 
+    @type_checked
+    def reindex(self, block: BlockRV, buffer_index: int, buffer_index_type: 
str) -> BlockRV:
+        """Create a block that reads/writes a buffer region into a read/write 
cache with reindexing.
+        The layout of the cache will be determined by the iterators of the 
block that reads/writes
+        the buffer. It requires:
+        1) There is only one block that reads/writes the target buffer
+        2) There is only one buffer load/store of this buffer in the block
+
+        Parameters
+        ----------
+        block: BlockRV
+            The block that accesses the target buffer
+        buffer_index: int
+            The index of the buffer in block's read or write region
+        buffer_index_type : str
+            Type of the buffer index, "read" or "write"
+
+        Returns
+        -------
+        reindex_block : BlockRV
+            The block of the reindex stage
+
+        Examples
+        --------
+
+        Before reindex, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def before_reindex(
+                A: T.Buffer[(128, 128), "float32"],
+                B: T.Buffer[(128, 128), "float32"]
+            ) -> None:
+                for i, j in T.grid(128, 128):
+                    with T.block("B"):
+                        vi, vj = T.axis.remap("SS", [i, j])
+                        B[vi, vj] = A[vj, vi] * 2.0
+
+        Create the schedule and do reindex:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_reindex)
+            block = sch.get_block("B")
+            sch.reindex(block, 0, "read")
+
+        After applying reindex, the IR becomes:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def after_reindex(
+                A: T.Buffer[(128, 128), "float32"],
+                B: T.Buffer[(128, 128), "float32"]
+            ) -> None:
+                A_reindex = T.alloc_buffer((128, 128), "float32")
+                for i, j in T.grid(128, 128):
+                    with T.block("A_reindex"):
+                        vi, vj = T.axis.remap("SS", [i, j])
+                        A_reindex[vi, vj] = A[vj, vi]
+                for i, j in T.grid(128, 128):
+                    with T.block("B"):
+                        vi, vj = T.axis.remap("SS", [i, j])
+                        B[vi, vj] = A_reindex[vi, vj] * 2.0
+
+        """
+        assert buffer_index_type in ["read", "write"], "Invalid 
buffer_index_type"
+        buffer_index_type_enum = 0 if buffer_index_type == "read" else 1
+        return _ffi_api.ScheduleReIndex(  # type: ignore # pylint: 
disable=no-member
+            self, block, buffer_index, buffer_index_type_enum
+        )
+
     ########## Schedule: Compute location ##########
 
     @type_checked
diff --git a/src/tir/schedule/concrete_schedule.cc 
b/src/tir/schedule/concrete_schedule.cc
index 2289899c32..590a0f0025 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -511,6 +511,16 @@ BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& 
block_rv, int write_buff
   return CreateRV<BlockRV>(result);
 }
 
+BlockRV ConcreteScheduleNode::ReIndex(const BlockRV& block_rv, int 
buffer_index,
+                                      BufferIndexType buffer_index_type) {
+  StmtSRef result{nullptr};
+  TVM_TIR_SCHEDULE_BEGIN();
+  result = tir::ReIndex(state_, this->GetSRef(block_rv), buffer_index, 
buffer_index_type);
+  TVM_TIR_SCHEDULE_END("reindex", this->error_render_level_);
+  this->state_->DebugVerify();
+  return CreateRV<BlockRV>(result);
+}
+
 /******** Schedule: Compute location ********/
 
 void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& 
loop_rv,
diff --git a/src/tir/schedule/concrete_schedule.h 
b/src/tir/schedule/concrete_schedule.h
index 8e83aac2ce..70c0265611 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -109,6 +109,8 @@ class ConcreteScheduleNode : public ScheduleNode {
                     const String& storage_scope) override;
   BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
                      const String& storage_scope) override;
+  BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
+                  BufferIndexType buffer_index_type) override;
   /******** Schedule: Compute location ********/
   void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool 
preserve_unit_loops) override;
   void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index 50dedf71ff..f4dba69c6b 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -253,6 +253,21 @@ TVM_DLL StmtSRef CacheRead(ScheduleState self, const 
StmtSRef& block_sref, int r
  */
 TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, 
int write_buffer_index,
                             const String& storage_scope);
+/*!
+ *
+ * \brief Create a block that reads/writes a buffer region into a read/write 
cache with reindexing.
+ * The layout of the cache will be determined by the iterators of the block 
that reads/writes the
+ * buffer. It requires:
+ * 1) There is only one block that reads/writes the target buffer
+ * 2) There is only one buffer load/store of this buffer in the block
+ * \param self The state of the schedule
+ * \param block_sref The block that operates on the target buffer.
+ * \param buffer_index The index of the buffer in block's read or write region.
+ * \param buffer_index_type The type of the buffer index, kRead or kWrite.
+ * \return The reindex stage block.
+ */
+TVM_DLL StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int 
buffer_index,
+                         BufferIndexType buffer_index_type);
 /******** Schedule: Compute location ********/
 /*!
  * \brief Move a producer block under the specific loop, and regenerate the
diff --git a/src/tir/schedule/primitive/cache_read_write.cc 
b/src/tir/schedule/primitive/cache_read_write.cc
index 1bba2ae4fc..c96f88e1f6 100644
--- a/src/tir/schedule/primitive/cache_read_write.cc
+++ b/src/tir/schedule/primitive/cache_read_write.cc
@@ -160,6 +160,121 @@ Block MakeCacheStage(const BufferRegion& cache_region, 
CacheStageInfo* info,
   return block;
 }
 
+/*!
+ * \brief Create the reindex block and generate the corresponding outer loops.
+ * \details The reindex block is a data copy block between the reindex buffer 
(the intermediate
+ * buffer), and the target buffer.
+    If buffer_index_type == kWrite, copy from the reindex buffer to the target 
buffer.
+    If buffer_index_type == kRead, copy from the target buffer to the reindex 
buffer.
+    The reindex block has the same block iters and the surrounding loops as 
the input block.
+ However, if a block iter is not used in the indices of the target buffer 
being reindexed, the
+ domain of the block iter, and the corresponding outer loop, will become 
constant value one, making
+ it a trivial iter.
+ * \param block The block to be reindexed
+ * \param info The cache info
+ * \param covered The set of block iter vars covered in the buffer access 
indices
+ * \param original_indices The original buffer access indices
+ * \param buffer_index The index of the target buffer
+ * \param buffer_index_type The type of buffer index
+ * \return The reindex block.
+ */
+Block MakeReIndexStage(const Block& block, CacheStageInfo* info,
+                       const std::unordered_set<Var, ObjectPtrHash, 
ObjectPtrEqual>& covered,
+                       const Array<PrimExpr>& original_indices, int 
buffer_index,
+                       BufferIndexType buffer_index_type) {
+  // iters of the reindex block
+  Array<IterVar> new_block_iters;
+  // the substitution map from the original block iter to the iters of the 
reindex block
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectEqual> 
block_var_replace_map;
+  // block access region of reindexed buffer and target buffer
+  Region reindex_region, target_region;
+  // indices to access the reindex buffer and the target buffer
+  Array<PrimExpr> reindex_indices, target_indices;
+
+  // Step 1: Create block iters, access regions of the reindex block, and 
accessing indices to the
+  // reindex buffer.
+  for (const IterVar& iter : block->iter_vars) {
+    Var var("v" + std::to_string(new_block_iters.size()));
+    bool used = covered.count(iter->var);
+    new_block_iters.push_back(IterVar(/*dom=*/used ? iter->dom : 
Range::FromMinExtent(0, 1),
+                                      /*var=*/var,
+                                      /*IterVarType=*/kDataPar));
+    if (used) {
+      reindex_indices.push_back(var);
+      reindex_region.push_back(Range::FromMinExtent(var, 1));
+    }
+    block_var_replace_map[iter->var] = var;
+  }
+
+  // Step 2: Replace the original block iters with the new block iters
+  BufferRegion buffer_region = buffer_index_type == BufferIndexType::kWrite
+                                   ? block->writes[buffer_index]
+                                   : block->reads[buffer_index];
+  target_region = Substitute(buffer_region->region, block_var_replace_map);
+  for (const PrimExpr& index : original_indices) {
+    target_indices.push_back(Substitute(index, block_var_replace_map));
+  }
+
+  // Step 3: Create the reindex block
+
+  // The src and the dst region and indices of the data copy
+  Region src_region{nullptr};
+  Region dst_region{nullptr};
+  Array<PrimExpr> src_indices{nullptr};
+  Array<PrimExpr> dst_indices{nullptr};
+
+  if (buffer_index_type == BufferIndexType::kWrite) {
+    src_region = reindex_region;
+    dst_region = target_region;
+    src_indices = reindex_indices;
+    dst_indices = target_indices;
+  } else {
+    src_region = target_region;
+    dst_region = reindex_region;
+    src_indices = target_indices;
+    dst_indices = reindex_indices;
+  }
+
+  // Create the body block
+  Block new_block(
+      /*iter_vars=*/new_block_iters,
+      /*reads=*/
+      {BufferRegion(info->read_buffer, src_region)},
+      /*writes=*/
+      {BufferRegion(info->write_buffer, dst_region)},
+      /*name_hint=*/buffer_region->buffer->name + "_reindex",
+      /*body=*/
+      BufferStore(info->write_buffer, BufferLoad(info->read_buffer, 
src_indices), dst_indices));
+
+  // Step 4: Create surrounding loops
+
+  // Create loop vars and bindings for block iters
+  std::vector<Var> loop_vars;         // loop variables
+  std::vector<PrimExpr> iter_values;  // bindings in block realize
+  for (int i = 0; i < static_cast<int>(block->iter_vars.size()); ++i) {
+    Var loop_var("ax" + std::to_string(loop_vars.size()));
+    loop_vars.push_back(loop_var);
+    iter_values.push_back(loop_var);
+  }
+
+  // Create the block realize node
+  Stmt body = BlockRealize(/*values=*/iter_values,
+                           /*predicate=*/const_true(),
+                           /*block=*/new_block);
+
+  // Create the chain of loops
+  for (int i = static_cast<int>(new_block_iters.size()) - 1; i >= 0; --i) {
+    body = For(/*loop_var=*/loop_vars[i],
+               /*min=*/new_block_iters[i]->dom->min,
+               /*extent=*/new_block_iters[i]->dom->extent,
+               /*kind=*/ForKind::kSerial,
+               /*body=*/std::move(body));
+  }
+  // Update cache info, which will be used in the later rewriting.
+  info->cache_stage = std::move(body);
+  return new_block;
+}
+
 /*!
  * \brief Recalculate the `affine_binding` flag of a specifc block
  * \param block_sref The sref to the specific block
@@ -599,6 +714,252 @@ class CacheWriteRewriter : public StmtExprMutator {
   bool under_writer_block_{false};
 };
 
+/*!
+ * \brief Create a new buffer, reshaped with block iters, to be used as the 
reindex buffer
+ * \param buffer The given buffer.
+ * \param block_iters The block iters.
+ * \param covered Set of block iter vars covered by the buffer access indices
+ * \return The new buffer with target shape.
+ */
+Buffer CreateReindexBuffer(const Buffer& buffer, const Array<IterVar>& 
block_iters,
+                           const std::unordered_set<Var, ObjectPtrHash, 
ObjectPtrEqual>& covered) {
+  ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*buffer.get());
+  ObjectPtr<VarNode> new_var = make_object<VarNode>(*buffer->data.get());
+  std::vector<PrimExpr> new_shape;
+  std::vector<PrimExpr> new_strides;
+  for (const auto& iter : block_iters) {
+    if (covered.count(iter->var)) {
+      new_shape.push_back(iter->dom->min + iter->dom->extent);
+    }
+  }
+  new_strides.clear();
+  new_buffer->shape = new_shape;
+  new_buffer->strides = new_strides;
+  new_buffer->data = buffer->data.copy_with_suffix("_reindex");
+  new_buffer->name = buffer->name + "_reindex";
+  return Buffer(new_buffer);
+}
+
+/*! \brief The schedule error that the target is not a leaf block. */
+class NotLeafBlockError : public ScheduleError {
+ public:
+  NotLeafBlockError(IRModule mod, Block block) : mod_(std::move(mod)), 
block_(std::move(block)) {}
+  String FastErrorString() const final {
+    return "ScheduleError: The target block is not a leaf block.";
+  }
+
+  String DetailRenderTemplate() const final { return "The target block {0} is 
not a leaf block."; }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+  IRModule mod_;
+  Block block_;
+};
+
+/*! \brief The schedule error that the buffer access is invalid for reindex. */
+class InvalidBufferAccessError : public ScheduleError {
+ public:
+  enum class ErrorKind {
+    kNoAccess,         // buffer access not found
+    kNonUniqueAccess,  // multiple buffer accesses with different indices
+    kOpaqueAccess,     // opaque access to the buffer
+  };
+
+  InvalidBufferAccessError(IRModule mod, Buffer buffer, Block block, ErrorKind 
kind)
+      : mod_(std::move(mod)), buffer_(std::move(buffer)), 
block_(std::move(block)), kind_(kind) {}
+  String FastErrorString() const final {
+    return "ScheduleError: The target buffer should be accessed via BufferLoad 
or BufferStore. The "
+           "indices should be the same if there are multiple accesses to the 
target buffer.";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "The target buffer " << buffer_->name
+       << " should be accessed in the leaf block {0} via BufferLoad or 
BufferStore. The indices "
+          "should be the same if there are multiple accesses to the target 
buffer. ";
+    if (kind_ == ErrorKind::kNoAccess) {
+      os << "No buffer accesses found.";
+    } else if (kind_ == ErrorKind::kNonUniqueAccess) {
+      os << "Multiple buffer accesses have non-unique indices.";
+    } else if (kind_ == ErrorKind::kOpaqueAccess) {
+      os << "Opaque buffer accesses found.";
+    }
+    return os.str();
+  }
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+
+ private:
+  IRModule mod_;
+  Buffer buffer_;
+  Block block_;
+  ErrorKind kind_;
+};
+
+/*! \brief Collect the related Load/Store to reindex */
+class ReIndexCollector : public StmtExprVisitor {
+ public:
+  static Array<PrimExpr> Collect(const IRModule& mod, const Buffer& buffer, 
const Block& block) {
+    ReIndexCollector collector(mod, buffer, block);
+    collector(block->body);
+    if (!collector.buffer_access_indices_.defined()) {
+      throw InvalidBufferAccessError(mod, buffer, block,
+                                     
InvalidBufferAccessError::ErrorKind::kNoAccess);
+    }
+    return collector.buffer_access_indices_.value();
+  }
+
+ private:
+  explicit ReIndexCollector(const IRModule& mod, const Buffer& buffer, const 
Block& block)
+      : mod_(mod), buffer_(buffer), block_(block) {}
+
+  void VisitExpr_(const BufferLoadNode* load) final {
+    StmtExprVisitor::VisitExpr_(load);
+    if (load->buffer.same_as(buffer_)) {
+      CheckAndUpdateBufferAccessIndices(load->indices);
+    }
+  }
+
+  void VisitStmt_(const BlockNode* block) final {
+    // no sub-blocks under this block
+    throw NotLeafBlockError(mod_, block_);
+  }
+
+  void VisitStmt_(const BufferStoreNode* store) final {
+    StmtExprVisitor::VisitStmt_(store);
+    if (store->buffer.same_as(buffer_)) {
+      CheckAndUpdateBufferAccessIndices(store->indices);
+    }
+  }
+
+  void CheckAndUpdateBufferAccessIndices(const Array<PrimExpr> indices) {
+    if (!buffer_access_indices_.defined()) {
+      buffer_access_indices_ = indices;
+      return;
+    } else if (!std::equal(buffer_access_indices_.value().begin(),
+                           buffer_access_indices_.value().end(), 
indices.begin(), indices.end(),
+                           ExprDeepEqual())) {
+      throw InvalidBufferAccessError(mod_, buffer_, block_,
+                                     
InvalidBufferAccessError::ErrorKind::kNonUniqueAccess);
+    }
+  }
+
+  void VisitExpr_(const VarNode* var) final {
+    if (var == buffer_->data.get()) {
+      throw InvalidBufferAccessError(mod_, buffer_, block_,
+                                     
InvalidBufferAccessError::ErrorKind::kOpaqueAccess);
+    }
+  }
+  /*! \brief The IR module */
+  IRModule mod_;
+  /*! \brief The buffer to rewrite */
+  Buffer buffer_;
+  /*! \brief The block to visit */
+  Block block_;
+  /*! \brief The indices of buffer access to rewrite */
+  Optional<Array<PrimExpr>> buffer_access_indices_;
+};
+
+/*! \brief Mutator of ReIndex */
+class ReIndexRewriter : public StmtExprMutator {
+ public:
+  static Stmt Rewrite(const StmtSRef& scope_sref, const StmtSRef& block_sref, 
CacheStageInfo* info,
+                      const std::unordered_set<Var, ObjectPtrHash, 
ObjectPtrEqual>& covered) {
+    ReIndexRewriter rewriter(block_sref, info, covered);
+    return rewriter(GetRef<Stmt>(scope_sref->stmt));
+  }
+
+ private:
+  explicit ReIndexRewriter(const StmtSRef& block_sref, CacheStageInfo* info,
+                           const std::unordered_set<Var, ObjectPtrHash, 
ObjectPtrEqual>& covered)
+      : block_sref_(block_sref), info_(info), covered_(covered) {
+    new_buffer_ = info->alloc;
+    old_buffer_ = info->read_buffer.same_as(new_buffer_) ? info->write_buffer 
: info->read_buffer;
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    Block old_stmt = GetRef<Block>(block);
+    if (is_scope_) {
+      is_scope_ = false;
+      Block stmt = Downcast<Block>(StmtExprMutator::VisitStmt_(block));
+      // Insert cache stage into the loop
+      ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
+      n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage);
+      n->alloc_buffers.push_back(info_->alloc);
+      stmt = Block(n);
+      info_->block_reuse.Set(old_stmt, stmt);
+      return stmt;
+    }
+
+    // Visiting the block being reindexed
+    if (block == block_sref_->stmt) {
+      // Collect the updated indices and regions
+      for (const IterVar& iter : block->iter_vars) {
+        if (covered_.count(iter->var)) {
+          indices_.push_back(iter->var);
+          region_.push_back(Range::FromMinExtent(iter->var, 1));
+        }
+      }
+      Block stmt = Downcast<Block>(StmtExprMutator::VisitStmt_(block));
+      // Update block reads/writes to use the intermediate reindex buffer
+      auto writes =
+          ReplaceBufferRegion(block->writes, old_buffer_, 
BufferRegion{new_buffer_, region_});
+      auto reads =
+          ReplaceBufferRegion(block->reads, old_buffer_, 
BufferRegion{new_buffer_, region_});
+      auto match_buffers = ReplaceBufferRegion(block->match_buffers, 
old_buffer_,
+                                               BufferRegion{new_buffer_, 
region_});
+      if (!writes.same_as(block->writes) || !reads.same_as(block->reads) ||
+          !match_buffers.same_as(block->match_buffers)) {
+        ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
+        n->writes = std::move(writes);
+        n->reads = std::move(reads);
+        n->match_buffers = std::move(match_buffers);
+        stmt = Block(n);
+      }
+      info_->block_reuse.Set(old_stmt, stmt);
+      return stmt;
+    }
+    return old_stmt;
+  }
+
+  template <typename Node>
+  Node VisitBufferAccess(Node node) {
+    if (node->buffer.same_as(old_buffer_)) {
+      auto* n = node.CopyOnWrite();
+      n->buffer = new_buffer_;
+      n->indices = indices_;
+    }
+    return node;
+  }
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    BufferStore buffer_store = 
Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
+    return VisitBufferAccess(std::move(buffer_store));
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    BufferLoad buffer_load = 
Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
+    return VisitBufferAccess(std::move(buffer_load));
+  }
+
+ private:
+  /*! \brief The block to be reindexed. */
+  const StmtSRef& block_sref_;
+  /*! \brief The info for inserting reindex stage. */
+  CacheStageInfo* info_;
+  /*! \brief Whether old block var is covered in the indices */
+  const std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>& covered_;
+  /*! \brief Whether the current block is scope block */
+  bool is_scope_{true};
+  /*! \brief The buffer to be replaced */
+  Buffer old_buffer_;
+  /*! \brief The reindex buffer */
+  Buffer new_buffer_;
+  /*! \brief The new indices */
+  Array<PrimExpr> indices_;
+  /*! \brief The new region */
+  Region region_;
+};
+
 /******** Implementation ********/
 
 StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int 
read_buffer_index,
@@ -729,6 +1090,80 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& 
block_sref, int write_bu
   return result_block_sref;
 }
 
+StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int 
buffer_index,
+                 BufferIndexType buffer_index_type) {
+  const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref);
+  Block block = GetRef<Block>(block_ptr);
+  Buffer buffer =
+      GetNthAccessBuffer(self, block, buffer_index, buffer_index_type == 
BufferIndexType::kWrite);
+  StmtSRef scope_sref = GetScopeRoot(self, block_sref, 
/*require_stage_pipeline=*/true);
+  arith::Analyzer analyzer;
+
+  // Step 1. Collect the original indices and check there is only a single 
pattern of related
+  // Load/Store and the buffer is not accessed opaquely
+  Array<PrimExpr> original_indices = ReIndexCollector::Collect(self->mod, 
buffer, block);
+  // Simplify the indices if possible
+  for (const IterVar& iter : block->iter_vars) {
+    analyzer.Bind(iter->var, iter->dom);
+  }
+  original_indices.MutateByApply(
+      [&analyzer](const PrimExpr& expr) { return analyzer.Simplify(expr); });
+
+  // Collect block iters appearing in the original_indices
+  std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> covered;
+  for (const PrimExpr& index : original_indices) {
+    PreOrderVisit(index, [&](const ObjectRef& obj) -> bool {
+      if (const VarNode* var = obj.as<VarNode>()) {
+        covered.insert(GetRef<Var>(var));
+      }
+      return true;
+    });
+  }
+
+  // Step 2. Creating CacheStageInfo
+  CacheStageInfo info;
+  // Create the corresponding buffer to be read(write), i.e. the result of 
reindex read(write)
+  if (buffer_index_type == BufferIndexType::kWrite) {
+    info.read_buffer = CreateReindexBuffer(buffer, block->iter_vars, covered);
+    info.write_buffer = buffer;
+    info.alloc = info.read_buffer;
+  } else {
+    info.read_buffer = buffer;
+    info.write_buffer = CreateReindexBuffer(buffer, block->iter_vars, covered);
+    info.alloc = info.write_buffer;
+  }
+
+  // Step 3. Check the block belongs to a chain loop nesting under the scope,
+  //         and get the insert location
+  const StmtSRefNode* loop;
+  for (loop = block_sref->parent; loop->parent != scope_sref.get();) {
+    const ForNode* outer = loop->parent->StmtAs<ForNode>();
+    const ForNode* inner = loop->StmtAs<ForNode>();
+    ICHECK(outer != nullptr && inner != nullptr);
+    ICHECK(outer->body.get() == inner);
+    loop = loop->parent;
+  }
+
+  info.loc_pos = loop->seq_index == -1 ? 0 : loop->seq_index;
+  if (buffer_index_type == BufferIndexType::kWrite) {
+    info.loc_pos++;
+  }
+
+  // Step 4. Making new reindex stage block and rewrite
+  Block reindex_stage =
+      MakeReIndexStage(block, &info, covered, original_indices, buffer_index, 
buffer_index_type);
+  Stmt new_scope = ReIndexRewriter::Rewrite(scope_sref, block_sref, &info, 
covered);
+
+  // Step 5. Replacing and updating flags
+  self->Replace(scope_sref, new_scope, info.block_reuse);
+  StmtSRef result_block_sref = self->stmt2ref.at(reindex_stage.get());
+  BlockInfo& block_info = self->block_info[result_block_sref];
+  block_info.affine_binding = CalculateAffineFlag(self, result_block_sref);
+  block_info.region_cover = true;
+  block_info.scope->stage_pipeline = true;
+  return result_block_sref;
+}
+
 /******** Instruction Registration ********/
 
 struct CacheReadTraits : public UnpackedInstTraits<CacheReadTraits> {
@@ -787,7 +1222,40 @@ struct CacheWriteTraits : public 
UnpackedInstTraits<CacheWriteTraits> {
   friend struct ::tvm::tir::UnpackedInstTraits;
 };
 
+struct ReIndexTraits : public UnpackedInstTraits<ReIndexTraits> {
+  static constexpr const char* kName = "ReIndex";
+  static constexpr bool kIsPure = false;
+
+ private:
+  static constexpr size_t kNumInputs = 1;
+  static constexpr size_t kNumAttrs = 2;
+  static constexpr size_t kNumDecisions = 0;
+
+  static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer 
buffer_index,
+                                         Integer buffer_index_type) {
+    return sch->ReIndex(block, buffer_index,
+                        
static_cast<BufferIndexType>(buffer_index_type->value));
+  }
+
+  static String UnpackedAsPython(Array<String> outputs, String block, Integer 
buffer_index,
+                                 Integer buffer_index_type) {
+    PythonAPICall py("reindex");
+    py.Input("block", block);
+    py.Input("buffer_index", buffer_index);
+    py.Input("buffer_index_type", '"' +
+                                      std::string(BufferIndexType2Str(
+                                          
static_cast<BufferIndexType>(buffer_index_type->value))) +
+                                      '"');
+    py.SingleOutput(outputs);
+    return py.Str();
+  }
+
+  template <typename>
+  friend struct ::tvm::tir::UnpackedInstTraits;
+};
+
 TVM_REGISTER_INST_KIND_TRAITS(CacheReadTraits);
 TVM_REGISTER_INST_KIND_TRAITS(CacheWriteTraits);
+TVM_REGISTER_INST_KIND_TRAITS(ReIndexTraits);
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index fb884ce77f..3880d0b19e 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -165,6 +165,11 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheRead")
     .set_body_method<Schedule>(&ScheduleNode::CacheRead);
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheWrite")
     .set_body_method<Schedule>(&ScheduleNode::CacheWrite);
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReIndex")
+    .set_body_typed([](Schedule self, const BlockRV& block_rv, int 
buffer_index,
+                       int buffer_index_type) {
+      return self->ReIndex(block_rv, buffer_index, 
static_cast<BufferIndexType>(buffer_index_type));
+    });
 /******** (FFI) Compute location ********/
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeAt")
     .set_body_method<Schedule>(&ScheduleNode::ComputeAt);
diff --git a/src/tir/schedule/traced_schedule.cc 
b/src/tir/schedule/traced_schedule.cc
index 8156480a45..d2f627edfd 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -265,6 +265,18 @@ BlockRV TracedScheduleNode::CacheWrite(const BlockRV& 
block_rv, int write_buffer
   return result;
 }
 
+BlockRV TracedScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index,
+                                    BufferIndexType buffer_index_type) {
+  BlockRV result = ConcreteScheduleNode::ReIndex(block_rv, buffer_index, 
buffer_index_type);
+
+  static const InstructionKind& kind = InstructionKind::Get("ReIndex");
+  trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
+                                      /*inputs=*/{block_rv},
+                                      /*attrs=*/{Integer(buffer_index), 
Integer(buffer_index_type)},
+                                      /*outputs=*/{result}));
+  return result;
+}
+
 /******** Schedule: Compute location ********/
 
 void TracedScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& 
loop_rv,
diff --git a/src/tir/schedule/traced_schedule.h 
b/src/tir/schedule/traced_schedule.h
index d1860be951..ba4a4b99cb 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -73,6 +73,8 @@ class TracedScheduleNode : public ConcreteScheduleNode {
                     const String& storage_scope) final;
   BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
                      const String& storage_scope) final;
+  BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
+                  BufferIndexType buffer_index_type) final;
   /******** Schedule: Compute location ********/
   void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops) final;
   void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc
index 79802ecd65..67d0f55f20 100644
--- a/src/tir/schedule/transform.cc
+++ b/src/tir/schedule/transform.cc
@@ -70,6 +70,32 @@ Array<MatchBufferRegion> ReplaceBuffer(Array<MatchBufferRegion> match_buffers, c
   return match_buffers;
 }
 
+Array<BufferRegion> ReplaceBufferRegion(Array<BufferRegion> regions, const Buffer& source_buffer,
+                                        const BufferRegion& target) {
+  regions.MutateByApply([&source_buffer, &target](const BufferRegion& region) -> BufferRegion {
+    if (region->buffer.same_as(source_buffer)) {
+      return target;
+    }
+    return region;
+  });
+  return regions;
+}
+
+Array<MatchBufferRegion> ReplaceBufferRegion(Array<MatchBufferRegion> match_buffers,
+                                             const Buffer& source_buffer,
+                                             const BufferRegion& target) {
+  match_buffers.MutateByApply([&source_buffer, &target](
+                                  const MatchBufferRegion& match_buffer) -> MatchBufferRegion {
+    if (match_buffer->source->buffer.same_as(source_buffer)) {
+      ObjectPtr<MatchBufferRegionNode> n = make_object<MatchBufferRegionNode>(*match_buffer.get());
+      n->source = target;
+      return MatchBufferRegion(n);
+    }
+    return match_buffer;
+  });
+  return match_buffers;
+}
+
 /******** ReplaceBufferMutator ********/
 ReplaceBufferMutator::ReplaceBufferMutator(const Buffer& old_buffer, Buffer new_buffer,
                                            Map<Block, Block>* block_sref_reuse)
diff --git a/src/tir/schedule/transform.h b/src/tir/schedule/transform.h
index 192d44d9e9..908a823c2d 100644
--- a/src/tir/schedule/transform.h
+++ b/src/tir/schedule/transform.h
@@ -73,6 +73,27 @@ Array<BufferRegion> ReplaceBuffer(Array<BufferRegion> regions, const Buffer& sou
 Array<MatchBufferRegion> ReplaceBuffer(Array<MatchBufferRegion> match_buffers, const Buffer& source,
                                        const Buffer& target);
 
+/*!
+ * \brief Replaces the buffer region within the specific sequence of regions
+ * \param regions The regions to be replaced
+ * \param source_buffer The buffer whose region is to be replaced
+ * \param target The buffer region to replace it with
+ * \return The new sequence of regions after replacement
+ */
+Array<BufferRegion> ReplaceBufferRegion(Array<BufferRegion> regions, const Buffer& source_buffer,
+                                        const BufferRegion& target);
+
+/*!
+ * \brief Replaces the buffer region within the specific sequence of match_buffers
+ * \param match_buffers The match_buffers to be replaced
+ * \param source_buffer The buffer whose source region is to be replaced
+ * \param target The buffer region to replace it with
+ * \return The new sequence of match_buffers after replacement
+ */
+Array<MatchBufferRegion> ReplaceBufferRegion(Array<MatchBufferRegion> match_buffers,
+                                             const Buffer& source_buffer,
+                                             const BufferRegion& target);
+
 /*!
  * \brief A helper mutator which recursively replaces the old buffer with the new buffer and
  * collects the block sref reuse information for the following replacement.
diff --git a/tests/python/unittest/test_tir_schedule_reindex.py b/tests/python/unittest/test_tir_schedule_reindex.py
new file mode 100644
index 0000000000..9b2e37a198
--- /dev/null
+++ b/tests/python/unittest/test_tir_schedule_reindex.py
@@ -0,0 +1,203 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import pytest
+import tvm
+import tvm.testing
+from tvm import tir
+from tvm.script import tir as T
+from tvm.tir.schedule.schedule import ScheduleError
+from tvm.tir.schedule.testing import verify_trace_roundtrip
+
+
+@T.prim_func
+def transpose_elementwise(
+    A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]
+) -> None:
+    for i, j in T.grid(128, 128):
+        with T.block("B"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            B[vi, vj] = A[vj, vi] * 2.0
+
+
+@T.prim_func
+def transpose_elementwise_reindex_read(
+    A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]
+) -> None:
+    A_reindex = T.alloc_buffer((128, 128), "float32")
+    for i, j in T.grid(128, 128):
+        with T.block("A_reindex"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            A_reindex[vi, vj] = A[vj, vi]
+    for i, j in T.grid(128, 128):
+        with T.block("B"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            B[vi, vj] = A_reindex[vi, vj] * 2.0
+
+
+@T.prim_func
+def conv2d_nhwc(
+    Input: T.Buffer[(1, 224, 224, 3), "float32"],
+    Weight: T.Buffer[(7, 7, 3, 64), "float32"],
+    Conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float32"],
+) -> None:
+    PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32")
+    for i0, i1, i2, i3 in T.grid(1, 230, 230, 3):
+        with T.block("PadInput"):
+            i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+            PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
+                ((((i1_1 >= 3) and (i1_1 < 227)) and (i2_1 >= 3)) and (i2_1 < 227)),
+                Input[i0_1, (i1_1 - 3), (i2_1 - 3), i3_1],
+                T.float32(0),
+                dtype="float32",
+            )
+    for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 112, 112, 64, 7, 7, 3):
+        with T.block("conv2d_nhwc"):
+            n, h, w, co, rh, rw, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6])
+            with T.init():
+                Conv2d_nhwc[n, h, w, co] = T.float32(0)
+            Conv2d_nhwc[n, h, w, co] = Conv2d_nhwc[n, h, w, co] + (
+                PadInput[n, ((h * 2) + rh), ((w * 2) + rw), ((T.floordiv(co, 64) * 3) + rc)]
+                * Weight[rh, rw, rc, co]
+            )
+
+
+@T.prim_func
+def conv2d_nhwc_reindex_weight(
+    var_inputs: T.handle, var_weight: T.handle, var_conv2d_nhwc: T.handle
+) -> None:
+    inputs = T.match_buffer(var_inputs, [1, 224, 224, 3], dtype="float32")
+    weight = T.match_buffer(var_weight, [7, 7, 3, 64], dtype="float32")
+    conv2d_nhwc = T.match_buffer(var_conv2d_nhwc, [1, 112, 112, 64], dtype="float32")
+    PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32")
+    weight_reindex = T.alloc_buffer([64, 7, 7, 3], dtype="float32")
+    for i0, i1, i2, i3 in T.grid(1, 230, 230, 3):
+        with T.block("PadInput"):
+            i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+            T.reads(inputs[i0_1, i1_1 - 3, i2_1 - 3, i3_1])
+            T.writes(PadInput[i0_1, i1_1, i2_1, i3_1])
+            PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
+                i1_1 >= 3 and i1_1 < 227 and i2_1 >= 3 and i2_1 < 227,
+                inputs[i0_1, i1_1 - 3, i2_1 - 3, i3_1],
+                T.float32(0),
+                dtype="float32",
+            )
+    for ax0, ax1, ax2, ax3, ax4, ax5, ax6 in T.grid(1, 1, 1, 64, 7, 7, 3):
+        with T.block("weight_reindex"):
+            v0, v1, v2, v3, v4, v5, v6 = T.axis.remap(
+                "SSSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5, ax6]
+            )
+            T.reads(weight[v4, v5, v6, v3])
+            T.writes(weight_reindex[v3, v4, v5, v6])
+            weight_reindex[v3, v4, v5, v6] = weight[v4, v5, v6, v3]
+    for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 112, 112, 64, 7, 7, 3):
+        with T.block("conv2d_nhwc"):
+            n, h, w, co, rh, rw, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6])
+            T.reads(
+                PadInput[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc],
+                weight_reindex[co, rh, rw, rc],
+            )
+            T.writes(conv2d_nhwc[n, h, w, co])
+            with T.init():
+                conv2d_nhwc[n, h, w, co] = T.float32(0)
+            conv2d_nhwc[n, h, w, co] = (
+                conv2d_nhwc[n, h, w, co]
+                + PadInput[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc]
+                * weight_reindex[co, rh, rw, rc]
+            )
+
+
+@T.prim_func
+def matmul(
+    A: T.Buffer[(512, 512), "float32"],
+    B: T.Buffer[(512, 512), "float32"],
+    C: T.Buffer[(512, 512), "float32"],
+) -> None:
+    for i0, i1, i2 in T.grid(512, 512, 512):
+        with T.block("matmul"):
+            i, j, k = T.axis.remap("SSR", [i0, i1, i2])
+            T.reads(C[i, j], A[i, k], B[k, j])
+            T.writes(C[i, j])
+            with T.init():
+                C[i, j] = T.float32(0)
+            C[i, j] = C[i, j] + A[i, k] * B[k, j]
+
+
+@T.prim_func
+def matmul_reindex_write(
+    A: T.Buffer[(512, 512), "float32"],
+    B: T.Buffer[(512, 512), "float32"],
+    C: T.Buffer[(512, 512), "float32"],
+) -> None:
+    C_reindex = T.alloc_buffer([512, 512], dtype="float32")
+    for i0, i1, i2 in T.grid(512, 512, 512):
+        with T.block("matmul"):
+            i, j, k = T.axis.remap("SSR", [i0, i1, i2])
+            T.reads(C_reindex[i, j], A[i, k], B[k, j])
+            T.writes(C_reindex[i, j])
+            with T.init():
+                C_reindex[i, j] = T.float32(0)
+            C_reindex[i, j] = C_reindex[i, j] + A[i, k] * B[k, j]
+    for i0, i1, i2 in T.grid(512, 512, 1):
+        with T.block("C_reindex"):
+            v0, v1, v2 = T.axis.remap("SSS", [i0, i1, i2])
+            T.reads(C_reindex[v0, v1])
+            T.writes(C[v0, v1])
+            C[v0, v1] = C_reindex[v0, v1]
+
+
+@T.prim_func
+def multiple_read(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]) -> None:
+    for i, j in T.grid(128, 128):
+        with T.block("B"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            B[vi, vj] = A[vj, vi] + A[vi, vj]
+
+
+def test_reindex_read_basic():
+    sch = tir.Schedule(transpose_elementwise)
+    block = sch.get_block("B")
+    sch.reindex(block, 0, "read")
+    tvm.ir.assert_structural_equal(transpose_elementwise_reindex_read, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=transpose_elementwise)
+
+
+def test_conv2d_reindex_read():
+    sch = tir.Schedule(conv2d_nhwc)
+    block = sch.get_block("conv2d_nhwc")
+    sch.reindex(block, 1, "read")
+    tvm.ir.assert_structural_equal(conv2d_nhwc_reindex_weight, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=conv2d_nhwc)
+
+
+def test_matmul_reindex_write():
+    sch = tir.Schedule(matmul)
+    block = sch.get_block("matmul")
+    sch.reindex(block, 0, "write")
+    tvm.ir.assert_structural_equal(matmul_reindex_write, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=matmul)
+
+
+def test_reindex_fail_multiple_read():
+    sch = tir.Schedule(multiple_read)
+    block = sch.get_block("B")
+    with pytest.raises(ScheduleError):
+        sch.reindex(block, 0, "read")
+
+
+if __name__ == "__main__":
+    tvm.testing.main()

Reply via email to