This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b01ab9e  [TensorIR][M2a] CacheRead/Write (#8863)
b01ab9e is described below

commit b01ab9e81e6a9605e6d2dce5b0c81ce551c1839b
Author: Siyuan Feng <[email protected]>
AuthorDate: Wed Sep 1 03:59:53 2021 +0800

    [TensorIR][M2a] CacheRead/Write (#8863)
    
    Co-authored-by: Junru Shao <[email protected]>
    Co-authored-by: Wuwei Lin <[email protected]>
    Co-authored-by: Ruihang Lai <[email protected]>
    Co-authored-by: Hongyi Jin <[email protected]>
    Co-authored-by: Siyuan Feng <[email protected]>
    Co-authored-by: Bohan Hou <[email protected]>
---
 include/tvm/tir/schedule/schedule.h                |  22 +
 include/tvm/tir/schedule/state.h                   |   5 +
 python/tvm/tir/schedule/schedule.py                | 135 ++++
 src/tir/schedule/analysis.h                        |  21 +-
 src/tir/schedule/analysis/analysis.cc              |  50 +-
 src/tir/schedule/concrete_schedule.cc              |  21 +
 src/tir/schedule/concrete_schedule.h               |   4 +
 src/tir/schedule/primitive.h                       |  24 +
 src/tir/schedule/primitive/block_annotate.cc       |   4 +-
 src/tir/schedule/primitive/cache_read_write.cc     | 781 +++++++++++++++++++++
 src/tir/schedule/schedule.cc                       |   4 +
 src/tir/schedule/state.cc                          |  18 +
 src/tir/schedule/traced_schedule.cc                |  23 +
 src/tir/schedule/traced_schedule.h                 |   4 +
 src/tir/schedule/transform.cc                      |  40 ++
 src/tir/schedule/transform.h                       |  29 +
 src/tir/schedule/utils.h                           |   1 +
 .../unittest/test_tir_schedule_cache_read_write.py | 677 ++++++++++++++++++
 18 files changed, 1840 insertions(+), 23 deletions(-)

diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h
index 79fed09..33776cb 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -282,6 +282,28 @@ class ScheduleNode : public runtime::Object {
    */
   virtual void Unroll(const LoopRV& loop_rv) = 0;
   /******** Schedule: Insert cache stages ********/
+  /*!
+   * \brief Create a block that reads a buffer region into a read cache. It requires:
+   * 1) There is at most one block that writes the buffer in the scope.
+   * 2) The scope block has the stage-pipeline property.
+   * \param block_rv The consumer block of the target buffer.
+   * \param read_buffer_index The index of the buffer in block's read region.
+   * \param storage_scope The target storage scope.
+   * \return The cache stage block.
+   */
+  virtual BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index,
+                            const String& storage_scope) = 0;
+  /*!
+   * \brief Create a block that writes a buffer region into a write cache. It requires:
+   * 1) There is only one block that writes the target buffer.
+   * 2) The scope block has the stage-pipeline property.
+   * \param block_rv The producer of the buffer.
+   * \param write_buffer_index The index of the buffer in block's write region.
+   * \param storage_scope The target storage scope.
+   * \return The cache stage block.
+   */
+  virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
+                             const String& storage_scope) = 0;
   /******** Schedule: Compute location ********/
   /*!
    * \brief Inline a block into its consumer(s). It requires:
diff --git a/include/tvm/tir/schedule/state.h b/include/tvm/tir/schedule/state.h
index 7cd1b00..35299a3 100644
--- a/include/tvm/tir/schedule/state.h
+++ b/include/tvm/tir/schedule/state.h
@@ -129,6 +129,11 @@ class ScheduleStateNode : public Object {
   TVM_DLL void Replace(const tir::StmtSRef& src_sref, const Stmt& tgt_stmt,
                        const Map<Block, Block>& block_sref_reuse);
   /*!
+   * \brief Recalculate the cached `affine_binding` flag of the given scope block's info.
+   * \param scope_sref The sref of the scope block whose info is recalculated.
+   */
+  TVM_DLL void UpdateAffineFlag(const StmtSRef& scope_sref);
+  /*!
    * \brief Trigger the verification according to the `debug_mask` bitmask.
    * 1) If the bitmask `kVerifySRefTree` is on, verify the correctness of the sref tree.
    * 2) If the bitmask `kVerifyCachedFlags` is on, verify the correctness of `affine_binding`,
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index 9433d01..ac09bdb 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -790,6 +790,141 @@ class Schedule(Object):
 
     ########## Schedule: Insert cache stages ##########
 
+    def cache_read(self, block: BlockRV, read_buffer_index: int, storage_scope: str) -> BlockRV:
+        """Create a block that reads a buffer region into a read cache. It requires:
+
+        1) There is at most one block that writes the buffer in the scope.
+
+        2) The scope block has the stage-pipeline property.
+
+        Parameters
+        ----------
+        block : BlockRV
+            The consumer block of the target buffer.
+
+        read_buffer_index : int
+            The index of the buffer in block's read region.
+
+        storage_scope : str
+            The target storage scope.
+
+        Returns
+        -------
+        cached_block : BlockRV
+            The block of the cache stage.
+
+        Examples
+        --------
+        Before cache_read, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def before_cache_read(a: ty.handle, b: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128))
+                B = tir.match_buffer(b, (128, 128))
+                for i, j in tir.grid(128, 128):
+                    with tir.block([128, 128], "B") as [vi, vj]:
+                        B[vi, vj] = A[vi, vj] * 2.0
+
+        Create the schedule and cache_read:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_cache_read)
+            block_b = sch.get_block("B")
+            sch.cache_read(block_b, 0, "local")
+            print(tvm.script.asscript(sch.mod["main"]))
+
+        After applying cache_read, the IR becomes:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def after_cache_read(a: ty.handle, b: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128))
+                B = tir.match_buffer(b, (128, 128))
+                A_local = tir.alloc_buffer((128, 128), scope="local")
+                for i, j in tir.grid(128, 128):
+                    with tir.block([128, 128], "A_local") as [vi, vj]:
+                        A_local[vi, vj] = A[vi, vj]
+                for i, j in tir.grid(128, 128):
+                    with tir.block([128, 128], "B") as [vi, vj]:
+                        B[vi, vj] = A_local[vi, vj] * 2.0
+
+        """
+        return _ffi_api.ScheduleCacheRead(  # type: ignore # pylint: disable=no-member
+            self, block, read_buffer_index, storage_scope
+        )
+
+    def cache_write(self, block: BlockRV, write_buffer_index: int, storage_scope: str) -> BlockRV:
+        """Create a block that writes a buffer region into a write cache. It requires:
+
+        1) There is only one block that writes the buffer in the scope.
+
+        2) The scope block has the stage-pipeline property.
+
+        Parameters
+        ----------
+        block : BlockRV
+            The producer block of the target buffer.
+
+        write_buffer_index : int
+            The index of the buffer in block's write region.
+
+        storage_scope : str
+            The target storage scope.
+
+
+        Returns
+        -------
+        cached_block : BlockRV
+            The block of the cache stage.
+
+        Examples
+        --------
+        Before cache_write, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def before_cache_write(a: ty.handle, b: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128))
+                B = tir.match_buffer(b, (128, 128))
+                for i, j in tir.grid(128, 128):
+                    with tir.block([128, 128], "B") as [vi, vj]:
+                        B[vi, vj] = A[vi, vj] * 2.0
+
+        Create the schedule and cache_write:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_cache_write)
+            block_b = sch.get_block("B")
+            sch.cache_write(block_b, 0, "local")
+            print(tvm.script.asscript(sch.mod["main"]))
+
+        After applying cache_write, the IR becomes:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def after_cache_write(a: ty.handle, b: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128))
+                B = tir.match_buffer(b, (128, 128))
+                B_local = tir.alloc_buffer((128, 128), scope="local")
+                for i, j in tir.grid(128, 128):
+                    with tir.block([128, 128], "B") as [vi, vj]:
+                        B_local[vi, vj] = A[vi, vj] * 2.0
+                for i, j in tir.grid(128, 128):
+                    with tir.block([128, 128], "B_local") as [vi, vj]:
+                        B[vi, vj] = B_local[vi, vj]
+
+        """
+        return _ffi_api.ScheduleCacheWrite(  # type: ignore # pylint: disable=no-member
+            self, block, write_buffer_index, storage_scope
+        )
+
     ########## Schedule: Compute location ##########
 
     def compute_inline(self, block: BlockRV) -> None:
diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
index 3fa0c63..d4e4728 100644
--- a/src/tir/schedule/analysis.h
+++ b/src/tir/schedule/analysis.h
@@ -56,6 +56,13 @@ void VerifyCachedFlags(const ScheduleState& self);
 const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_block,
                                     GlobalVar* result_g_var);
 
+/*!
+ * \brief Get the root node of the sref tree, which is the root block of the PrimFunc.
+ * \param sref The given sref; it must be defined (it is dereferenced without a null check).
+ * \return The root node of the sref tree which contains the given node.
+ */
+StmtSRef GetSRefTreeRoot(const StmtSRef& sref);
+
 /******** Scope ********/
 /*!
  * \brief Checks if scope the specified sref is in is a stage-pipeline and return it
@@ -228,15 +235,15 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr
 /******** Block-buffer relation ********/
 
 /*!
- * \brief Get the BlockRealize of the single child block of the block or loop specified by
- * `parent_sref` on SRef tree, or throw an exception if there is 0 or multiple child blocks
- * \param self The schedule state
- * \param block The queried block
- * \param n The index of the queried buffer
- * \return The buffer of the n-th write region of the block.
+ * \brief Get the n-th read or write buffer of the given block.
+ * \param self The schedule state.
+ * \param block The queried block.
+ * \param n The index of the queried buffer.
+ * \param is_write If true, index into the block's write regions; otherwise its read regions.
+ * \return The buffer of the n-th read/write region of the block.
  * \throw ScheduleError If the buffer index is out of bound.
  */
-Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n);
+Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, bool is_write);
 
 /******** Commutative Reducer ********/
 
diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc
index c9f8ff4..3865781 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -588,25 +588,37 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr
 
 /******** Block-buffer relation ********/
 
-Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n) 
{
-  class WriteBufferIndexOutOfRangeError : public ScheduleError {
+Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int 
n, bool is_write) {
+  class BufferIndexOutOfRangeError : public ScheduleError {
    public:
-    explicit WriteBufferIndexOutOfRangeError(IRModule mod, Block block, int 
buffer_index)
-        : mod_(std::move(mod)), block_(std::move(block)), 
buffer_index_(buffer_index) {}
+    explicit BufferIndexOutOfRangeError(IRModule mod, Block block, int 
buffer_index, bool is_write)
+        : mod_(std::move(mod)),
+          block_(std::move(block)),
+          buffer_index_(buffer_index),
+          is_write_(is_write) {}
 
     String FastErrorString() const final {
-      return "ScheduleError: The input `buffer_index` is out of range. It is 
required to be in "
-             "range [0, num_write_regions) where `num_write_regions` is the 
number of buffer "
-             "regions written by the block.";
+      if (is_write_) {
+        return "ScheduleError: The input `buffer_index` is out of range. It is 
required to be in "
+               "range "
+               "[0, num_write_regions) where `num_write_regions` is the number 
of buffer regions "
+               "written by the block.";
+      } else {
+        return "ScheduleError: The input `buffer_index` is out of range. It is 
required to be in "
+               "range "
+               "[0, num_read_regions) where `num_read_regions` is the number 
of buffer regions "
+               "read by the block.";
+      }
     }
 
     String DetailRenderTemplate() const final {
       std::ostringstream os;
-      size_t num_writes = block_->writes.size();
-      os << "The block {0} has " << num_writes
-         << " write regions, so `buffer_index` is required to be in [0, " << 
num_writes
+      size_t num = is_write_ ? block_->writes.size() : block_->reads.size();
+      std::string access_type = is_write_ ? "write" : "read";
+      os << "The block {0} has " << num << " " << access_type
+         << " regions, so `buffer_index` is required to be in [0, " << num
          << "). However, the input `buffer_index` is " << buffer_index_
-         << ", which is out of the expected range";
+         << ", which is out of the expected range.";
       return os.str();
     }
 
@@ -617,12 +629,15 @@ Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n) {
     IRModule mod_;
     Block block_;
     int buffer_index_;
+    bool is_write_;
   };
 
-  if (n < 0 || static_cast<size_t>(n) >= block->writes.size()) {
-    throw WriteBufferIndexOutOfRangeError(self->mod, block, n);
+  const Array<BufferRegion>& access_region = is_write ? block->writes : 
block->reads;
+
+  if (n < 0 || static_cast<int>(access_region.size()) <= n) {
+    throw BufferIndexOutOfRangeError(self->mod, block, n, is_write);
   }
-  return block->writes[n]->buffer;
+  return access_region[n]->buffer;
 }
 
 /******** Pattern Matcher ********/
@@ -941,5 +956,12 @@ bool FromIdentityCombiner(const PrimExpr& identity, const BufferStore& combiner,
   return false;
 }
 
+/******** SRef Tree Related ********/
+StmtSRef GetSRefTreeRoot(const StmtSRef& sref) {
+  const StmtSRefNode* p = sref.get();  // `sref` must be defined: no null check below.
+  for (; p->parent != nullptr; p = p->parent) {  // Climb parent links to the root.
+  }  // Intentionally empty: the loop header does all the work.
+  return GetRef<StmtSRef>(p);
+}
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index cd9aad8..86223e1 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -416,6 +416,27 @@ void ConcreteScheduleNode::Unroll(const LoopRV& loop_rv) {
 }
 
 /******** Schedule: Insert cache stages ********/
+
+BlockRV ConcreteScheduleNode::CacheRead(const BlockRV& block_rv, int read_buffer_index,
+                                        const String& storage_scope) {
+  StmtSRef result{nullptr};  // Assigned between the BEGIN/END macros below.
+  TVM_TIR_SCHEDULE_BEGIN();
+  result = tir::CacheRead(state_, this->GetSRef(block_rv), read_buffer_index, storage_scope);
+  TVM_TIR_SCHEDULE_END("cache-read", this->error_render_level_);
+  this->state_->DebugVerify();
+  return CreateRV<BlockRV>(result);
+}
+
+BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer_index,
+                                         const String& storage_scope) {
+  StmtSRef result{nullptr};  // Assigned between the BEGIN/END macros below.
+  TVM_TIR_SCHEDULE_BEGIN();
+  result = tir::CacheWrite(state_, this->GetSRef(block_rv), write_buffer_index, storage_scope);
+  TVM_TIR_SCHEDULE_END("cache-write", this->error_render_level_);
+  this->state_->DebugVerify();
+  return CreateRV<BlockRV>(result);
+}
+
 /******** Schedule: Compute location ********/
 
 void ConcreteScheduleNode::ComputeInline(const BlockRV& block_rv) {
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index 0bd902d..e756f9d 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -103,6 +103,10 @@ class ConcreteScheduleNode : public ScheduleNode {
   void Bind(const LoopRV& loop_rv, const String& thread_axis) override;
   void Unroll(const LoopRV& loop_rv) override;
   /******** Schedule: Insert cache stages ********/
+  BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index,
+                    const String& storage_scope) override;  // See ScheduleNode::CacheRead.
+  BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
+                     const String& storage_scope) override;  // See ScheduleNode::CacheWrite.
   /******** Schedule: Compute location ********/
   void ComputeInline(const BlockRV& block) override;
   void ReverseComputeInline(const BlockRV& block) override;
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index be33c2a..412611a 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -135,6 +135,30 @@ TVM_DLL void Bind(ScheduleState self, const StmtSRef& loop_sref, const IterVar&
  */
 TVM_DLL void Unroll(ScheduleState self, const StmtSRef& loop_sref);
 /******** Schedule: Insert cache stages ********/
+/*!
+ * \brief Create a block that reads a buffer region into a read cache. It requires:
+ * 1) There is at most one block that writes the buffer in the scope.
+ * 2) The scope block has the stage-pipeline property.
+ * \param self The state of the schedule.
+ * \param block_sref The consumer block of the target buffer.
+ * \param read_buffer_index The index of the buffer in block's read region.
+ * \param storage_scope The target storage scope.
+ * \return The cache stage block.
+ */
+TVM_DLL StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index,
+                           const String& storage_scope);
+/*!
+ * \brief Create a block that writes a buffer region into a write cache. It requires:
+ * 1) There is only one block that writes the target buffer.
+ * 2) The scope block has the stage-pipeline property.
+ * \param self The state of the schedule.
+ * \param block_sref The producer of the buffer.
+ * \param write_buffer_index The index of the buffer in block's write region.
+ * \param storage_scope The target storage scope.
+ * \return The cache stage block.
+ */
+TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index,
+                            const String& storage_scope);
 /******** Schedule: Compute location ********/
 /*!
  * \brief Inline a block into its consumer(s). It requires:
diff --git a/src/tir/schedule/primitive/block_annotate.cc b/src/tir/schedule/primitive/block_annotate.cc
index 937bc7c..06f7ac3 100644
--- a/src/tir/schedule/primitive/block_annotate.cc
+++ b/src/tir/schedule/primitive/block_annotate.cc
@@ -16,7 +16,6 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-#include "../transform.h"
 #include "../utils.h"
 
 namespace tvm {
@@ -237,7 +236,8 @@ class StorageAlignInvalidAnnotationError : public ScheduleError {
 void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int 
buffer_index, int axis,
                   int factor, int offset) {
   const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref);
-  Buffer buffer = GetNthWriteBuffer(self, GetRef<Block>(block_ptr), 
buffer_index);
+  Buffer buffer =
+      GetNthAccessBuffer(self, GetRef<Block>(block_ptr), buffer_index, 
/*is_write=*/true);
   StorageAlignInvalidFactorError::Check(self->mod, factor);
   axis = StorageAlignAxisOutOfRangeError::CheckAndUpdate(self->mod, buffer, 
axis);
   NonAllocatedBufferError::CheckBufferAllocated(self->mod, block_sref, buffer);
diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc
new file mode 100644
index 0000000..df54c96
--- /dev/null
+++ b/src/tir/schedule/primitive/cache_read_write.cc
@@ -0,0 +1,781 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "../utils.h"
+
+namespace tvm {
+namespace tir {
+
+/******** Error Classes ********/
+
+class NotSingleWriteBlock : public ScheduleError {  // Raised when >1 block writes a buffer.
+ public:
+  explicit NotSingleWriteBlock(IRModule mod, Buffer buffer, Array<StmtSRef> write_blocks)
+      : mod_(std::move(mod)), buffer_(std::move(buffer)) {
+    ICHECK_GT(write_blocks.size(), 1);  // Only meaningful with two or more writers.
+    write_blocks_.reserve(write_blocks.size());
+    for (const StmtSRef& block_sref : write_blocks) {
+      const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+      write_blocks_.push_back(GetRef<Block>(block));
+    }
+  }
+
+  String FastErrorString() const final {
+    return "ScheduleError: The buffer is allowed to be written by single block.";
+  }
+
+  String DetailRenderTemplate() const final {
+    size_t k = write_blocks_.size();
+    return "The buffer " + buffer_->name + " is expected to be written by single block, but got " +
+           std::to_string(k) + " blocks who write it.";
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final {
+    return {write_blocks_.begin(), write_blocks_.end()};
+  }
+
+ private:
+  IRModule mod_;
+  Buffer buffer_;  // The buffer that has more than one writer.
+  Array<Block> write_blocks_;  // The writer blocks, reported as locations of interest.
+};
+
+/******** Helper Functions/Classes ********/
+
+/*! \brief The auxiliary info used for the insertion point and content of the cache stage. */
+struct CacheStageInfo {
+  /*! \brief The buffer read by the cache stage, i.e. the source of the copy. */
+  Buffer read_buffer;
+  /*! \brief The buffer written by the cache stage, i.e. the destination of the copy. */
+  Buffer write_buffer;
+  /*! \brief The buffer allocation to be inserted into the block signature. */
+  Buffer alloc;
+  /*! \brief The AST node whose body is where the cache stage should be inserted. */
+  StmtSRef loc_sref;
+  /*! \brief The index in the body of `loc_sref` at which the cache stage is inserted. */
+  size_t loc_pos;
+  /*! \brief The cache_read/cache_write stage (a loop nest) to be inserted. */
+  Stmt cache_stage;
+  /*! \brief The map used for ScheduleStateNode::Replace. */
+  Map<Block, Block> block_reuse;
+};
+
+/*! \brief Return the region of `buffer` in `buffer_regions`, or NullOpt if it is absent. */
+Optional<BufferRegion> GetBufferRegionFromBuffer(const Array<BufferRegion>& buffer_regions,
+                                                 const Buffer& buffer) {
+  Optional<BufferRegion> res = NullOpt;
+  for (const auto& region : buffer_regions) {
+    if (region->buffer.same_as(buffer)) {
+      ICHECK(!res.defined());  // `buffer` must appear at most once.
+      res = region;
+    }
+  }
+  return res;
+}
+
+/*!
+ * \brief Create a loop nest that represents cache copy (cache_read / cache_write) from read buffer
+ *        to write buffer.
+ * \note The whole loop nest is stored into `info->cache_stage`; only the innermost block is
+ *        returned.
+ * \param cache_region The cached copy region; loop i covers the extent of its i-th range.
+ * \param info The cache stage info: reads `read_buffer`/`write_buffer`, sets `cache_stage`.
+ * \param storage_scope Only used for naming: the block is named `<buffer name>_<storage_scope>`.
+ * \returns A block indicating the body of the loop nesting.
+ */
+Block MakeCacheStage(const BufferRegion& cache_region, CacheStageInfo* info,
+                     const String& storage_scope) {
+  // loop variables
+  std::vector<Var> loop_vars;
+  // bindings in block realize
+  std::vector<PrimExpr> iter_values;
+  // Loop var `ax<i>` iterates [0, extent_i); block var i is bound to `min_i + ax<i>`.
+  for (const Range& axis_range : cache_region->region) {
+    Var loop_var("ax" + std::to_string(loop_vars.size()));
+    loop_vars.push_back(loop_var);
+    iter_values.push_back(axis_range->min + loop_var);
+  }
+  // block variables
+  Array<IterVar> block_vars;
+  // block access region for read/write buffers
+  Region access_region;
+  // indices used in block body
+  Array<PrimExpr> access_indices;
+  // One data-parallel block var per buffer dimension, with domain [0, shape_i).
+  for (const PrimExpr& dim : cache_region->buffer->shape) {
+    Var var("v" + std::to_string(access_indices.size()));
+    block_vars.push_back(IterVar(/*dom=*/Range::FromMinExtent(0, dim),
+                                 /*var=*/var,
+                                 /*IterVarType=*/kDataPar));
+    access_indices.push_back(var);
+    access_region.push_back(Range::FromMinExtent(var, 1));
+  }
+
+  // Create the body block:
+  //   reads = [read_buffer[access_region]]
+  //   writes = [write_buffer[access_region]]
+  //     write_buffer[access_indices] = read_buffer[access_indices]
+  Block block(
+      /*iter_vars=*/std::move(block_vars),
+      /*reads=*/{BufferRegion(info->read_buffer, access_region)},
+      /*writes=*/{BufferRegion(info->write_buffer, access_region)},
+      /*name_hint=*/cache_region->buffer->name + "_" + storage_scope,
+      /*body=*/
+      BufferStore(info->write_buffer, BufferLoad(info->read_buffer, access_indices),
+                  access_indices),
+      /*init=*/NullOpt,
+      /*alloc_buffers=*/{},
+      /*match_buffers=*/{},
+      /*annotations=*/{});
+  // Create the block realize node
+  Stmt body = BlockRealize(/*values=*/iter_values,
+                           /*predicate=*/Bool(true),
+                           /*block=*/block);
+  // Wrap the realize in serial loops, from the innermost (last axis) outwards.
+  for (size_t i = loop_vars.size(); i >= 1; --i) {
+    body = For(/*loop_var=*/loop_vars[i - 1],
+               /*min=*/0,
+               /*extent=*/cache_region->region[i - 1]->extent,
+               /*kind=*/ForKind::kSerial,
+               /*body=*/body);
+  }
+  info->cache_stage = std::move(body);
+  return block;
+}
+
+/*!
+ * \brief Insert the cache_read/cache_write stage into the specific position
+ * \param stmt A sequence of statements or a single statement that the new stage is inserted in
+ * \param pos The position to insert at; must be 0 or 1 if `stmt` is not a SeqStmt.
+ * \param stage The stage to be inserted
+ * \return A SeqStmt, the result after insertion
+ */
+SeqStmt InsertCacheStage(const Stmt& stmt, int pos, const Stmt& stage) {
+  if (const auto* seq_stmt = stmt.as<SeqStmtNode>()) {
+    ObjectPtr<SeqStmtNode> result = make_object<SeqStmtNode>(*seq_stmt);  // Copy, then edit.
+    result->seq.insert(result->seq.begin() + pos, stage);
+    return SeqStmt(result);
+  }
+  if (pos == 0) {
+    return SeqStmt({stage, stmt});
+  }
+  ICHECK_EQ(pos, 1);
+  return SeqStmt({stmt, stage});
+}
+
+/*!
+ * \brief Get the only writer block of the input buffer in a given scope block.
+ * \param self The state of the schedule.
+ * \param scope_sref The scope block where the write is considered.
+ * \param buffer The queried buffer.
+ * \return The sref of the only writer of the input buffer in the given scope,
+ *         or `NullOpt` if no block writes it in the scope.
+ * \throw NotSingleWriteBlock if more than one block in the scope writes the buffer.
+ */
+Optional<StmtSRef> GetOnlyWriteBlock(ScheduleState self, const StmtSRef& scope_sref,
+                                     const Buffer& buffer) {
+  BlockScope scope = self->GetBlockScope(scope_sref);
+  auto it = scope->buffer_writers.find(buffer);
+  if (it == scope->buffer_writers.end()) {
+    return NullOpt;
+  } else {
+    const Array<StmtSRef>& block_srefs = it->second;
+    ICHECK(!block_srefs.empty());
+    if (block_srefs.size() > 1) {
+      throw NotSingleWriteBlock(self->mod, buffer, block_srefs);
+    }
+    return block_srefs[0];
+  }
+}
+
+/*!
+ * \brief Get the buffer region under the sref tree path [dom_low_inclusive, dom_high_exclusive)
+ * \param self The state of the schedule.
+ * \param buffer_region The buffer region to be analyzed.
+ * \param block_sref The block whose iter-var bindings are substituted into the region.
+ * \param dom_low_inclusive The lowest node in the sref tree path.
+ * \param dom_high_exclusive The highest node in the sref tree path.
+ * \return The relaxed buffer region: one covering range per buffer dimension.
+ */
+BufferRegion RelaxBufferRegion(ScheduleState self, const BufferRegion& buffer_region,
+                               const StmtSRef& block_sref, const StmtSRef& dom_low_inclusive,
+                               const StmtSRef& dom_high_exclusive) {
+  BlockRealize realize = GetBlockRealize(self, block_sref);
+  Map<Var, PrimExpr> binding = GetBindings(realize);  // block var -> bound value
+  const Buffer& buffer = buffer_region->buffer;
+  Array<arith::IntSet> int_sets =
+      arith::EvalSet(Substitute(buffer_region->region, binding),
+                     AsIntSet(LoopDomainOfSRefTreePath(
+                         /*low_inclusive=*/dom_low_inclusive,
+                         /*high_exclusive=*/dom_high_exclusive,
+                         /*extra_relax_scope=*/runtime::StorageScope::Create(buffer.scope()))));
+  ICHECK_EQ(buffer_region->region.size(), int_sets.size());
+
+  Region region;
+  region.reserve(int_sets.size());
+  for (size_t i = 0; i < int_sets.size(); ++i) {
+    region.push_back(int_sets[i].CoverRange(Range::FromMinExtent(0, buffer->shape[i])));
+  }
+  return BufferRegion(buffer, region);
+}
+
+/*! \brief Detect the insertion position of the new cache stage */
+class CacheLocDetector : public StmtVisitor {
+ public:
+  /*!
+   * \brief Detect the insertion position of the cache stage, and write the position into the
+   * CacheStageInfo.
+   *
+   * The "related" blocks are the blocks in the scope that have a RAW dependency on `block_sref`,
+   * i.e. the blocks that read what `block_sref` writes. When there are none, `loc_sref` is the
+   * scope block and `loc_pos` is the number of statements in its body (1 if the body is not a
+   * SeqStmt).
+   *
+   * \param self The state of the schedule
+   * \param block_sref The sref of the unique writer block of the buffer being applied
+   *        cache_read or cache_write
+   * \param scope_sref The sref of the scope block of the cached block
+   * \param info The cache stage info; its `loc_sref` and `loc_pos` fields are filled in
+   */
+  static void Detect(const ScheduleState& self, const StmtSRef& block_sref,
+                     const StmtSRef& scope_sref, CacheStageInfo* info) {
+    std::vector<StmtSRef> related_blocks;
+    for (const Dependency& def : self->GetBlockScope(scope_sref)->GetDepsBySrc(block_sref)) {
+      if (def->kind == DepKind::kRAW) {  // `def->dst` reads what `block_sref` writes
+        related_blocks.push_back(def->dst);
+      }
+    }
+    if (!related_blocks.empty()) {
+      CacheLocDetector detector(self, block_sref, scope_sref, related_blocks);
+      detector(GetRef<Stmt>(scope_sref->stmt));
+      info->loc_sref = detector.loc_sref_;
+      info->loc_pos = detector.loc_pos_;
+    } else {
+      info->loc_sref = scope_sref;
+      const auto* body = scope_sref->StmtAs<BlockNode>()->body.as<SeqStmtNode>();
+      info->loc_pos = body == nullptr ? 1 : body->size();  // position after the last statement
+    }
+  }
+
+ private:
+  /*!
+   * \brief Constructor
+   * \param self The state of the schedule
+   * \param block_sref The sref of the unique writer block of the buffer being applied
+   *        cache_read or cache_write
+   * \param scope_sref The sref of the scope block of the cached block
+   * \param related_blocks The blocks with a RAW dependency on `block_sref` in the scope
+   *        (its consumers); the referenced vector must outlive the detector
+   */
+  CacheLocDetector(const ScheduleState self, const StmtSRef& block_sref, const StmtSRef& scope_sref,
+                   const std::vector<StmtSRef>& related_blocks)
+      : self_(self),
+        block_sref_(block_sref),
+        scope_sref_(scope_sref),
+        related_blocks_(related_blocks) {}
+
+  void VisitStmt_(const SeqStmtNode* seq_stmt) final {
+    bool previous_visited_block = visited_block_;  // saved: the flags are reset to track
+    bool previous_visited_related = visited_related_;  // only what lies under this sequence
+    visited_block_ = visited_related_ = false;
+
+    int pos = -1;
+    for (size_t i = 0; i < seq_stmt->size(); ++i) {
+      if (loc_pos_ != -1) {
+        break;
+      }
+      VisitStmt(seq_stmt->seq[i]);
+      // Record, only once, the index of the first child after which both flags are set
+      if (visited_block_ && visited_related_ && pos == -1) {
+        // This child index is the candidate insertion position within this sequence
+        pos = i;
+      }
+    }
+    visited_block_ = visited_block_ || previous_visited_block;
+    visited_related_ = visited_related_ || previous_visited_related;
+    // Both the writer block and at least one related block lie under this SeqStmt, so commit
+    // its candidate position, unless a deeper SeqStmt (visited earlier, since this is a
+    // post-order check) has already committed one
+    if (visited_block_ && visited_related_ && loc_pos_ == -1) {
+      loc_pos_ = pos;
+    }
+  }
+
+  void VisitStmt_(const BlockNode* block) final {
+    // Only the scope block is descended into; other blocks merely update the flags below
+    if (block == scope_sref_->stmt) {
+      // The block visited is the current parent scope
+      StmtVisitor::VisitStmt_(block);
+      // Fallback when no loop claimed the position (e.g. cache_read for an input buffer)
+      if (visited_block_ && visited_related_ && !loc_sref_.defined()) {
+        loc_sref_ = self_->stmt2ref.at(block);
+        if (loc_pos_ == -1) {
+          loc_pos_ = 1;
+        }
+      }
+      return;
+    }
+    // Mark that the writer block has been seen
+    if (block_sref_->stmt == block) {
+      visited_block_ = true;
+      return;
+    }
+    // Mark that one of the related blocks has been seen
+    for (const StmtSRef& related_block : related_blocks_) {
+      if (related_block->stmt == block) {
+        visited_related_ = true;
+        return;
+      }
+    }
+  }
+
+  void VisitStmt_(const ForNode* loop) final {
+    StmtVisitor::VisitStmt_(loop);
+    if (visited_block_ && visited_related_ && !loc_sref_.defined() && loc_pos_ != -1) {
+      loc_sref_ = self_->stmt2ref.at(loop);  // first loop finished after loc_pos_ was committed
+    }
+  }
+
+ private:
+  /*! \brief The schedule state */
+  const ScheduleState self_;
+  /*! \brief The sref of the block that writes the buffer */
+  const StmtSRef& block_sref_;
+  /*! \brief The sref of the scope block containing the writer block */
+  const StmtSRef& scope_sref_;
+  /*! \brief Blocks with a RAW dependency on `block_sref_` in the scope, i.e. its consumers */
+  const std::vector<StmtSRef>& related_blocks_;
+  /*! \brief Whether the writer block has been visited */
+  bool visited_block_{false};
+  /*! \brief Whether at least one of the related blocks has been visited */
+  bool visited_related_{false};
+  /*! \brief The AST node (loop or block) whose body is where the cache stage is inserted */
+  StmtSRef loc_sref_{nullptr};
+  /*! \brief The index to insert the cache_read/cache_write stage */
+  int loc_pos_{-1};
+};
+
+/*!
+ * \brief Mutator for CacheRead.
+ *
+ * Redirects accesses to `info->read_buffer` to the cache buffer `info->write_buffer` in every
+ * block that does not itself write `info->read_buffer`. It inserts the cache stage at
+ * `info->loc_sref` / `info->loc_pos` and allocates the cache buffer in the scope block.
+ */
+class CacheReadRewriter : public StmtExprMutator {
+ public:
+  /*!
+   * \brief Rewrite the AST and add a cache_read stage with the information provided
+   * \param scope_sref The parent scope of this mutation
+   * \param info The cache stage information; `info->block_reuse` is filled with the mapping
+   *        from the original blocks to their rewritten versions
+   * \return The new AST rooting at the original parent scope
+   */
+  static Stmt Rewrite(const StmtSRef& scope_sref, CacheStageInfo* info) {
+    CacheReadRewriter rewriter(scope_sref, info);
+    return rewriter(GetRef<Stmt>(scope_sref->stmt));
+  }
+
+ private:
+  explicit CacheReadRewriter(const StmtSRef& scope_sref, CacheStageInfo* info)
+      : scope_sref_(scope_sref), info_(info) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt stmt = StmtMutator::VisitStmt_(loop);
+    // Check the insertion point
+    if (loop == info_->loc_sref->stmt) {
+      // Insert cache stage into the loop if it is the right place
+      ObjectPtr<ForNode> n = make_object<ForNode>(*stmt.as<ForNode>());
+      n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage);
+      stmt = Stmt(n);
+    }
+    return stmt;
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    Block old_stmt = GetRef<Block>(block);
+    // We don't mutate the block which generates info->read_buffer
+    if (block != scope_sref_->stmt &&
+        GetBufferRegionFromBuffer(block->writes, info_->read_buffer).defined()) {
+      return std::move(old_stmt);
+    }
+    // Mutate the body
+    Block stmt = Downcast<Block>(StmtMutator::VisitStmt_(block));
+    // Check the insertion point
+    if (block == info_->loc_sref->stmt) {
+      // Insert cache stage into the block if it is the right place
+      ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
+      n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage);
+      stmt = Block(n);
+    }
+    // Check if it is the block corresponding to the parent scope
+    if (block == scope_sref_->stmt) {
+      // If so, put buffer allocation on the parent scope
+      ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
+      n->alloc_buffers.push_back(info_->alloc);
+      stmt = Block(n);
+    } else {
+      // Otherwise, update read regions and match_buffers
+      Array<BufferRegion> reads =
+          ReplaceBuffer(block->reads, info_->read_buffer, info_->write_buffer);
+      Array<MatchBufferRegion> match_buffers =
+          ReplaceBuffer(block->match_buffers, info_->read_buffer, info_->write_buffer);
+      if (!reads.same_as(block->reads) || !match_buffers.same_as(block->match_buffers)) {
+        ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
+        n->reads = std::move(reads);
+        n->match_buffers = std::move(match_buffers);
+        stmt = Block(n);
+      }
+    }
+    info_->block_reuse.Set(old_stmt, stmt);
+    return std::move(stmt);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* load) final {
+    // Mutate the indices first, so that nested accesses to the buffer are redirected as well
+    BufferLoad new_load = Downcast<BufferLoad>(ExprMutator::VisitExpr_(load));
+    if (!new_load->buffer.same_as(info_->read_buffer)) {
+      return std::move(new_load);
+    }
+    ObjectPtr<BufferLoadNode> n = make_object<BufferLoadNode>(*new_load.get());
+    n->buffer = info_->write_buffer;
+    return PrimExpr(n);
+  }
+
+  PrimExpr VisitExpr_(const LoadNode* load) final {
+    // Mutate the index and predicate first, so that nested accesses are redirected as well
+    Load new_load = Downcast<Load>(ExprMutator::VisitExpr_(load));
+    if (!new_load->buffer_var.same_as(info_->read_buffer->data)) {
+      return std::move(new_load);
+    }
+    ObjectPtr<LoadNode> n = make_object<LoadNode>(*new_load.get());
+    n->buffer_var = info_->write_buffer->data;
+    return PrimExpr(n);
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    // Direct references to the data pointer of the original buffer go to the cache buffer
+    if (op == info_->read_buffer->data.get()) {
+      return info_->write_buffer->data;
+    }
+    return GetRef<PrimExpr>(op);
+  }
+
+ private:
+  /*! \brief The parent scope of the insertion */
+  const StmtSRef& scope_sref_;
+  /*! \brief The info for inserting cache stage */
+  CacheStageInfo* info_;
+};
+
+/*!
+ * \brief Mutator for CacheWrite.
+ *
+ * Only the scope block and the writer block (with everything nested inside it) are visited.
+ * Within them, accesses to `info->write_buffer` are redirected to the cache buffer
+ * `info->read_buffer`. The cache stage is inserted at `info->loc_sref` / `info->loc_pos`, and
+ * the cache buffer is allocated in the scope block.
+ */
+class CacheWriteRewriter : public StmtExprMutator {
+ public:
+  /*!
+   * \brief Rewrite the AST and add a cache_write stage with the information provided.
+   * \param scope_sref The parent scope of this mutation.
+   * \param writer_block_sref The only writer block in the scope.
+   * \param info The cache stage information; `info->block_reuse` is filled with the mapping
+   *        from the original blocks to their rewritten versions.
+   * \return The new AST rooting at the original parent scope.
+   */
+  static Stmt Rewrite(const StmtSRef& scope_sref, const StmtSRef& writer_block_sref,
+                      CacheStageInfo* info) {
+    CacheWriteRewriter rewriter(scope_sref, writer_block_sref, info);
+    return rewriter(GetRef<Stmt>(scope_sref->stmt));
+  }
+
+ private:
+  explicit CacheWriteRewriter(const StmtSRef& scope_sref, const StmtSRef& writer_block_sref,
+                              CacheStageInfo* info)
+      : scope_sref_(scope_sref), writer_block_sref_(writer_block_sref), info_(info) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt stmt = StmtMutator::VisitStmt_(loop);
+    // Check the insertion point
+    if (loop == info_->loc_sref->stmt) {
+      // Insert cache stage into the loop if it is the right place
+      ObjectPtr<ForNode> n = make_object<ForNode>(*stmt.as<ForNode>());
+      n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage);
+      stmt = Stmt(n);
+    }
+    return stmt;
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    Block old_stmt = GetRef<Block>(block);
+    // We only mutate the block which generates info->write_buffer
+    if (block != writer_block_sref_->stmt && block != scope_sref_->stmt && !under_writer_block_) {
+      return std::move(old_stmt);
+    }
+
+    // Mutate the body; `under_writer_block_` stays set while visiting the writer's sub-tree
+    bool under_scope = under_writer_block_ || block == writer_block_sref_->stmt;
+    std::swap(under_scope, under_writer_block_);
+    Block stmt = Downcast<Block>(StmtMutator::VisitStmt_(block));
+    std::swap(under_scope, under_writer_block_);
+
+    // Find the insertion point
+    if (block == info_->loc_sref->stmt) {
+      ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
+      n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage);
+      stmt = Block(n);
+    }
+    // Put buffer allocation on the parent scope
+    if (block == scope_sref_->stmt) {
+      ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
+      n->alloc_buffers.push_back(info_->alloc);
+      stmt = Block(n);
+    } else {
+      // Since cache_write changes the block, we need to update the buffer it writes
+      auto writes = ReplaceBuffer(block->writes, info_->write_buffer, info_->read_buffer);
+      auto reads = ReplaceBuffer(block->reads, info_->write_buffer, info_->read_buffer);
+      auto match_buffers =
+          ReplaceBuffer(block->match_buffers, info_->write_buffer, info_->read_buffer);
+      if (!writes.same_as(block->writes) || !reads.same_as(block->reads) ||
+          !match_buffers.same_as(block->match_buffers)) {
+        ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
+        n->writes = std::move(writes);
+        n->reads = std::move(reads);
+        n->match_buffers = std::move(match_buffers);
+        stmt = Block(n);
+      }
+    }
+    info_->block_reuse.Set(old_stmt, stmt);
+    return std::move(stmt);
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* store) final {
+    BufferStore stmt = Downcast<BufferStore>(StmtMutator::VisitStmt_(store));
+    if (stmt->buffer.same_as(info_->write_buffer)) {
+      auto n = CopyOnWrite(stmt.get());
+      n->buffer = info_->read_buffer;
+      return Stmt(n);
+    } else {
+      return std::move(stmt);
+    }
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* load) final {
+    // Mutate the indices first, so that nested accesses to the buffer are redirected as well
+    BufferLoad new_load = Downcast<BufferLoad>(ExprMutator::VisitExpr_(load));
+    if (!new_load->buffer.same_as(info_->write_buffer)) {
+      return std::move(new_load);
+    }
+    ObjectPtr<BufferLoadNode> n = make_object<BufferLoadNode>(*new_load.get());
+    n->buffer = info_->read_buffer;
+    return PrimExpr(n);
+  }
+
+  PrimExpr VisitExpr_(const LoadNode* load) final {
+    // Mutate the index and predicate first, so that nested accesses are redirected as well
+    Load new_load = Downcast<Load>(ExprMutator::VisitExpr_(load));
+    if (!new_load->buffer_var.same_as(info_->write_buffer->data)) {
+      return std::move(new_load);
+    }
+    ObjectPtr<LoadNode> n = make_object<LoadNode>(*new_load.get());
+    n->buffer_var = info_->read_buffer->data;
+    return PrimExpr(n);
+  }
+
+  Stmt VisitStmt_(const StoreNode* store) final {
+    // Mutate value/index/predicate first, as the BufferStore case does, so that loads of the
+    // written buffer on the right-hand side are redirected to the cache buffer too
+    Store stmt = Downcast<Store>(StmtMutator::VisitStmt_(store));
+    if (!stmt->buffer_var.same_as(info_->write_buffer->data)) {
+      return std::move(stmt);
+    }
+    auto n = CopyOnWrite(stmt.get());
+    n->buffer_var = info_->read_buffer->data;
+    return Stmt(n);
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    // Direct references to the data pointer of the original buffer go to the cache buffer
+    if (op == info_->write_buffer->data.get()) {
+      return info_->read_buffer->data;
+    }
+    return GetRef<PrimExpr>(op);
+  }
+
+ private:
+  /*! \brief The parent scope of the insertion. */
+  const StmtSRef& scope_sref_;
+  /*! \brief The only block in the scope that writes `info_->write_buffer`. */
+  const StmtSRef& writer_block_sref_;
+  /*! \brief The info for inserting cache stage. */
+  CacheStageInfo* info_;
+  /*! \brief Whether the current node is under the writer block. */
+  bool under_writer_block_{false};
+};
+
+/******** Implementation ********/
+
+StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index,
+                   const String& storage_scope) {
+  /*!
+   * Insert a stage that copies the `read_buffer_index`-th read buffer of the block into a new
+   * buffer in `storage_scope`, redirect the readers to it, and return the new block's sref.
+   *
+   * Check:
+   *   - The index is within the block's read regions
+   *   - There is at most one block that writes the buffer in the scope
+   *
+   * Mutate:
+   *   - Allocate the new cache buffer under the current scope.
+   *   - Find the lowest ancestor of the writer block and ANY ONE of the consumer blocks.
+   *   - Copy the buffer over the consumed region.
+   */
+
+  // Step 1. Check index, getting the target buffer and the parent scope
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  Buffer read_buffer =
+      GetNthAccessBuffer(self, GetRef<Block>(block), read_buffer_index, /*is_write=*/false);
+  StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true);
+  const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref);
+
+  // Step 2. Create CacheStageInfo
+  CacheStageInfo info;
+  info.read_buffer = read_buffer;
+  // Create the corresponding buffer to be written, i.e. the cache buffer produced by cache_read
+  info.write_buffer = WithScope(read_buffer, storage_scope);
+  // The cache buffer is the one that needs a new allocation in the scope
+  info.alloc = info.write_buffer;
+
+  // Step 3. Update cache stage info.
+  BufferRegion cache_region{nullptr};
+  if (Optional<StmtSRef> _write_block_sref = GetOnlyWriteBlock(self, scope_sref, read_buffer)) {
+    // Case 1. The buffer is written by a block inside the scope.
+    StmtSRef write_block_sref = _write_block_sref.value();
+    const BlockNode* write_block = TVM_SREF_TO_BLOCK(write_block, write_block_sref);
+    // Find the producing region
+    BufferRegion region = GetBufferRegionFromBuffer(write_block->writes, read_buffer).value();
+    StmtSRef parent_sref = GetRef<StmtSRef>(write_block_sref->parent);
+
+    // Detect insert position
+    CacheLocDetector::Detect(self, write_block_sref, scope_sref, &info);
+    cache_region = RelaxBufferRegion(self, region, write_block_sref, parent_sref, info.loc_sref);
+  } else {
+    // Case 2. No writer block was found in the scope: the buffer is an input of the scope.
+    info.loc_sref = scope_sref;
+    info.loc_pos = 0;
+    if (Optional<BufferRegion> region =
+            GetBufferRegionFromBuffer(scope_block->reads, read_buffer)) {
+      cache_region = region.value();
+    } else {
+      cache_region = BufferRegion::FullRegion(read_buffer);
+    }
+  }
+
+  // Step 4. Make the new cache stage block and rewrite the readers.
+  Block cache_read_stage = MakeCacheStage(/*cache_region=*/cache_region, /*info=*/&info,
+                                          /*storage_scope=*/storage_scope);
+  Stmt new_scope = CacheReadRewriter::Rewrite(/*scope_sref=*/scope_sref, /*info=*/&info);
+
+  // Step 5. Replace the scope and update the flags of the new block.
+  self->Replace(scope_sref, new_scope, info.block_reuse);
+  StmtSRef result_block_sref = self->stmt2ref.at(cache_read_stage.get());
+  self->UpdateAffineFlag(result_block_sref);
+  BlockInfo& block_info = self->block_info[result_block_sref];
+  block_info.region_cover = true;
+  block_info.scope->stage_pipeline = true;
+  return result_block_sref;
+}
+
+StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index,
+                    const String& storage_scope) {
+  /*!
+   * Make the block write into a new buffer in `storage_scope` instead of its
+   * `write_buffer_index`-th write buffer. Insert a stage (built by MakeCacheStage) that copies
+   * the new buffer back into the original one, and return the new block's sref.
+   *
+   * Check:
+   *   - The index is within the block's write regions
+   *   - The block is the only block that writes the buffer in the scope
+   *
+   * Mutate:
+   *   - Allocate the new cache buffer under the current scope.
+   *   - Find the lowest ancestor of the block and ANY ONE of its consumer blocks.
+   *   - Copy the cache buffer back over the region the block produces.
+   */
+  // Step 1. Checking index, getting the target buffer and the parent scope
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  Buffer write_buffer =
+      GetNthAccessBuffer(self, GetRef<Block>(block), write_buffer_index, /*is_write=*/true);
+  StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true);
+
+  // Step 2. Creating CacheStageInfo
+  CacheStageInfo info;
+  info.read_buffer = WithScope(write_buffer, storage_scope);
+  // The original buffer, which the cache stage writes back into
+  info.write_buffer = write_buffer;
+  // The cache buffer is the one that needs a new allocation in the scope
+  info.alloc = info.read_buffer;
+
+  // Step 3. Check the only writer block.
+  ICHECK_EQ(block_sref.get(), GetOnlyWriteBlock(self, scope_sref, write_buffer).get());  // NOTE(review): internal check, not a user-facing ScheduleError
+
+  // Step 4. Find the producing region and insert position
+  BufferRegion region = GetBufferRegionFromBuffer(block->writes, write_buffer).value();
+  StmtSRef parent_sref = GetRef<StmtSRef>(block_sref->parent);
+  // Detect insert position
+  CacheLocDetector::Detect(self, block_sref, scope_sref, &info);
+  BufferRegion cache_region =
+      RelaxBufferRegion(self, region, block_sref, parent_sref, info.loc_sref);
+
+  // Step 5. Make the new cache stage block and rewrite the writer block.
+  Block cache_write_stage = MakeCacheStage(/*cache_region=*/cache_region, /*info=*/&info,
+                                           /*storage_scope=*/storage_scope);
+  Stmt new_scope = CacheWriteRewriter::Rewrite(/*scope_sref=*/scope_sref,
+                                               /*writer_block_sref=*/block_sref, /*info=*/&info);
+
+  // Step 6. Replace the scope and update the flags of the new block.
+  self->Replace(scope_sref, new_scope, info.block_reuse);
+  StmtSRef result_block_sref = self->stmt2ref.at(cache_write_stage.get());
+  self->UpdateAffineFlag(result_block_sref);
+  BlockInfo& block_info = self->block_info[result_block_sref];
+  block_info.region_cover = true;
+  block_info.scope->stage_pipeline = true;
+  return result_block_sref;
+}
+
+/******** Instruction Registration ********/
+
+struct CacheReadTraits : public UnpackedInstTraits<CacheReadTraits> {  // trace instruction for CacheRead
+  static constexpr const char* kName = "CacheRead";  // looked up via InstructionKind::Get("CacheRead")
+  static constexpr bool kIsPure = false;
+
+ private:
+  static constexpr size_t kNumInputs = 1;  // block
+  static constexpr size_t kNumAttrs = 2;  // read_buffer_index, storage_scope
+  static constexpr size_t kNumDecisions = 0;  // no decisions
+
+  static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer read_buffer_index,
+                                         String storage_scope) {
+    return sch->CacheRead(block, read_buffer_index->value, storage_scope);  // replay on `sch`
+  }
+
+  static String UnpackedAsPython(Array<String> outputs, String block, Integer read_buffer_index,
+                                 String storage_scope) {
+    PythonAPICall py("cache_read");  // Python-side method name of this instruction
+    py.Input("block", block);
+    py.Input("read_buffer_index", read_buffer_index->value);
+    py.Input("storage_scope", storage_scope);
+    py.SingleOutput(outputs);
+    return py.Str();
+  }
+
+  template <typename>
+  friend struct ::tvm::tir::UnpackedInstTraits;  // lets the base template reach the private members
+};
+
+struct CacheWriteTraits : public UnpackedInstTraits<CacheWriteTraits> {  // trace instruction for CacheWrite
+  static constexpr const char* kName = "CacheWrite";  // looked up via InstructionKind::Get("CacheWrite")
+  static constexpr bool kIsPure = false;
+
+ private:
+  static constexpr size_t kNumInputs = 1;  // block
+  static constexpr size_t kNumAttrs = 2;  // write_buffer_index, storage_scope
+  static constexpr size_t kNumDecisions = 0;  // no decisions
+
+  static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer write_buffer_index,
+                                         String storage_scope) {
+    return sch->CacheWrite(block, write_buffer_index->value, storage_scope);  // replay on `sch`
+  }
+
+  static String UnpackedAsPython(Array<String> outputs, String block, Integer write_buffer_index,
+                                 String storage_scope) {
+    PythonAPICall py("cache_write");  // Python-side method name of this instruction
+    py.Input("block", block);
+    py.Input("write_buffer_index", write_buffer_index->value);
+    py.Input("storage_scope", storage_scope);
+    py.SingleOutput(outputs);
+    return py.Str();
+  }
+
+  template <typename>
+  friend struct ::tvm::tir::UnpackedInstTraits;  // lets the base template reach the private members
+};
+
+TVM_REGISTER_INST_KIND_TRAITS(CacheReadTraits);  // registers the "CacheRead" instruction kind
+TVM_REGISTER_INST_KIND_TRAITS(CacheWriteTraits);  // registers the "CacheWrite" instruction kind
+}  // namespace tir
+}  // namespace tvm
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index d24cdc6..fd30b02 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -141,6 +141,10 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleVectorize")
 
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBind").set_body_method<Schedule>(&ScheduleNode::Bind);
 
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnroll").set_body_method<Schedule>(&ScheduleNode::Unroll);
 /******** (FFI) Insert cache stages ********/
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheRead")
+    .set_body_method<Schedule>(&ScheduleNode::CacheRead);  // FFI binding of ScheduleNode::CacheRead
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheWrite")
+    .set_body_method<Schedule>(&ScheduleNode::CacheWrite);  // FFI binding of ScheduleNode::CacheWrite
 /******** (FFI) Compute location ********/
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeInline")
     .set_body_method<Schedule>(&ScheduleNode::ComputeInline);
diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc
index 9a9b974..799806b 100644
--- a/src/tir/schedule/state.cc
+++ b/src/tir/schedule/state.cc
@@ -1029,6 +1029,24 @@ TVM_DLL Array<Bool> GetCachedFlags(const ScheduleState& 
self, const StmtSRef& bl
           Bool(info.scope->stage_pipeline)};
 }
 
+TVM_DLL void ScheduleStateNode::UpdateAffineFlag(const StmtSRef& block_sref) {
+  // Recompute `affine_binding` of a block that already has an entry in `block_info`
+  // (e.g. a cache stage that a primitive has just inserted).
+  auto entry = this->block_info.find(block_sref);
+  ICHECK(entry != this->block_info.end()) << "Cannot find the block info of the given block.";
+  BlockInfo& info = entry->second;
+  // A block without a parent sref is the root block, whose binding is treated as affine.
+  if (block_sref->parent == nullptr) {
+    info.affine_binding = true;
+    return;
+  }
+  // Otherwise check the block's binding against the domains of the loops above it.
+  arith::Analyzer analyzer;
+  BlockRealize realize = GetBlockRealize(GetRef<ScheduleState>(this), block_sref);
+  auto loop_var_ranges = LoopDomainOfSRefTreePath(GetRef<StmtSRef>(block_sref->parent));
+  info.affine_binding = IsAffineBinding(realize, loop_var_ranges, &analyzer);
+}
+
 /**************** FFI ****************/
 
 TVM_REGISTER_NODE_TYPE(ScheduleStateNode);
diff --git a/src/tir/schedule/traced_schedule.cc 
b/src/tir/schedule/traced_schedule.cc
index af4a658..f429a91 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -166,6 +166,29 @@ void TracedScheduleNode::Unroll(const LoopRV& loop_rv) {
 }
 
 /******** Schedule: Insert cache stages ********/
+BlockRV TracedScheduleNode::CacheRead(const BlockRV& block_rv, int read_buffer_index,
+                                      const String& storage_scope) {
+  // Apply the primitive first, then append the matching instruction to the trace.
+  BlockRV cache_block =
+      ConcreteScheduleNode::CacheRead(block_rv, read_buffer_index, storage_scope);
+  static const InstructionKind& kind = InstructionKind::Get("CacheRead");
+  Array<ObjectRef> inputs{block_rv};
+  Array<ObjectRef> attrs{Integer(read_buffer_index), storage_scope};
+  trace_->Append(/*inst=*/Instruction(kind, inputs, attrs, /*outputs=*/{cache_block}));
+  return cache_block;
+}
+
+BlockRV TracedScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer_index,
+                                       const String& storage_scope) {
+  // Apply the primitive first, then append the matching instruction to the trace.
+  BlockRV cache_block =
+      ConcreteScheduleNode::CacheWrite(block_rv, write_buffer_index, storage_scope);
+  static const InstructionKind& kind = InstructionKind::Get("CacheWrite");
+  Array<ObjectRef> inputs{block_rv};
+  Array<ObjectRef> attrs{Integer(write_buffer_index), storage_scope};
+  trace_->Append(/*inst=*/Instruction(kind, inputs, attrs, /*outputs=*/{cache_block}));
+  return cache_block;
+}
 
 /******** Schedule: Compute location ********/
 
diff --git a/src/tir/schedule/traced_schedule.h 
b/src/tir/schedule/traced_schedule.h
index 48dadbc..a6b5251 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -71,6 +71,10 @@ class TracedScheduleNode : public ConcreteScheduleNode {
   void Bind(const LoopRV& loop_rv, const String& thread_axis) final;
   void Unroll(const LoopRV& loop_rv) final;
   /******** Schedule: Insert cache stages ********/
+  BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index,
+                    const String& storage_scope) final;  // also appends a CacheRead instruction
+  BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
+                     const String& storage_scope) final;  // also appends a CacheWrite instruction
   /******** Schedule: Compute location ********/
   void ComputeInline(const BlockRV& block_rv) final;
   void ReverseComputeInline(const BlockRV& block_rv) final;
diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc
index f27e0f6..da376fd 100644
--- a/src/tir/schedule/transform.cc
+++ b/src/tir/schedule/transform.cc
@@ -19,6 +19,8 @@
 
 #include "./transform.h"
 
+#include "./utils.h"
+
 namespace tvm {
 namespace tir {
 
@@ -31,5 +33,43 @@ Block WithAnnotation(const BlockNode* block, const String& 
attr_key, const Objec
   return Block(new_block);
 }
 
+/******** Buffer Related ********/
+Buffer WithScope(const Buffer& buffer, const String& scope) {
+  // The storage scope is carried by the pointer type of the data var, so the copy gets a
+  // fresh data var whose pointer type names the new scope; both names get a "_<scope>" suffix.
+  const auto* ptr_type = TVM_TYPE_AS(ptr_type, buffer->data->type_annotation, PointerTypeNode);
+  PointerType scoped_ptr_type(ptr_type->element_type, scope);
+  ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*buffer.get());
+  new_buffer->data = Var(buffer->data->name_hint + "_" + scope, scoped_ptr_type);
+  new_buffer->name = buffer->name + "_" + scope;
+  return Buffer(new_buffer);
+}
+
+Array<BufferRegion> ReplaceBuffer(Array<BufferRegion> regions, const Buffer& source,
+                                  const Buffer& target) {
+  // MutateByApply is kept on purpose: callers compare the result with `same_as` to detect
+  // whether any region was actually rewritten.
+  auto redirect = [&source, &target](BufferRegion region) -> BufferRegion {
+    if (!region->buffer.same_as(source)) {
+      return region;
+    }
+    ObjectPtr<BufferRegionNode> updated = make_object<BufferRegionNode>(*region.get());
+    updated->buffer = target;
+    return BufferRegion(updated);
+  };
+  regions.MutateByApply(redirect);
+  return regions;
+}
+
+Array<MatchBufferRegion> ReplaceBuffer(Array<MatchBufferRegion> match_buffers, const Buffer& source,
+                                       const Buffer& target) {
+  // Only the source region is redirected; the matched buffer itself is left as is.
+  // MutateByApply is kept so that callers can use `same_as` to detect a change.
+  auto redirect_source = [&source, &target](MatchBufferRegion match_buffer) -> MatchBufferRegion {
+    if (!match_buffer->source->buffer.same_as(source)) {
+      return match_buffer;
+    }
+    ObjectPtr<MatchBufferRegionNode> updated =
+        make_object<MatchBufferRegionNode>(*match_buffer.get());
+    updated->source = BufferRegion(target, updated->source->region);
+    return MatchBufferRegion(updated);
+  };
+  match_buffers.MutateByApply(redirect_source);
+  return match_buffers;
+}
+
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/schedule/transform.h b/src/tir/schedule/transform.h
index 5348382..85cce9d 100644
--- a/src/tir/schedule/transform.h
+++ b/src/tir/schedule/transform.h
@@ -35,6 +35,35 @@ namespace tir {
  */
 Block WithAnnotation(const BlockNode* block, const String& attr_key, const 
ObjectRef& attr_value);
 
+/******** Buffer Related ********/
+
+/*!
+ * \brief Create a new buffer by changing the storage scope.
+ * \param buffer The given buffer.
+ * \param scope The target storage scope.
+ * \return A copy of the buffer with a fresh data var whose pointer type carries the target
+ *         storage scope; both the buffer name and the data var name get a "_<scope>" suffix.
+ */
+Buffer WithScope(const Buffer& buffer, const String& scope);
+
+/*!
+ * \brief Replaces the buffer within the specific sequence of regions
+ * \param regions The regions whose buffers are to be replaced
+ * \param source The buffer to be replaced
+ * \param target The buffer to substitute for `source`
+ * \return The new sequence of regions after replacement; callers use `same_as` on the result
+ *         to detect whether anything was replaced
+ */
+Array<BufferRegion> ReplaceBuffer(Array<BufferRegion> regions, const Buffer& source,
+                                  const Buffer& target);
+
+/*!
+ * \brief Replaces the buffer within the source regions of a sequence of match_buffers
+ * \param match_buffers The match_buffers whose source buffers are to be replaced
+ * \param source The buffer to be replaced
+ * \param target The buffer to substitute for `source`
+ * \return The new sequence of match_buffers after replacement; callers use `same_as` on the
+ *         result to detect whether anything was replaced
+ */
+Array<MatchBufferRegion> ReplaceBuffer(Array<MatchBufferRegion> match_buffers, const Buffer& source,
+                                       const Buffer& target);
 }  // namespace tir
 }  // namespace tvm
 
diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h
index 8ccf8da..c2f4301 100644
--- a/src/tir/schedule/utils.h
+++ b/src/tir/schedule/utils.h
@@ -42,6 +42,7 @@
 #include "./error.h"
 #include "./instruction_traits.h"
 #include "./primitive.h"
+#include "./transform.h"
 
 namespace tvm {
 namespace tir {
diff --git a/tests/python/unittest/test_tir_schedule_cache_read_write.py 
b/tests/python/unittest/test_tir_schedule_cache_read_write.py
new file mode 100644
index 0000000..d7eb8d8
--- /dev/null
+++ b/tests/python/unittest/test_tir_schedule_cache_read_write.py
@@ -0,0 +1,677 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import sys
+
+import pytest
+import tvm
+from tvm import tir
+from tvm.script import ty
+from tvm.tir.schedule.testing import verify_trace_roundtrip
+
+# pylint: disable=no-member,invalid-name,unused-variable
+
+########## Function before schedule ##########
+
+
+@tvm.script.tir
+def elementwise(a: ty.handle, c: ty.handle) -> None:  # B = A * 2.0, then C = B + 1.0
+    A = tir.match_buffer(a, (128, 128))
+    B = tir.alloc_buffer((128, 128))
+    C = tir.match_buffer(c, (128, 128))
+    with tir.block([128, 128], "B") as [vi, vj]:
+        B[vi, vj] = A[vi, vj] * 2.0
+    with tir.block([128, 128], "C") as [vi, vj]:
+        C[vi, vj] = B[vi, vj] + 1.0
+
+
+@tvm.script.tir
+def access_under_scope(b: ty.handle, c: ty.handle) -> None:
+    A = tir.alloc_buffer((128, 128))
+    B = tir.match_buffer(b, (128, 128))
+    C = tir.match_buffer(c, (128, 128))
+
+    with tir.block([8, 8], "scope") as [i, j]:  # per 16x16 tile: write A, then B = A + 1
+        for x, y in tir.grid(16, 16):
+            with tir.block([128, 128], "A") as [vi, vj]:
+                tir.bind(vi, i * 16 + x)
+                tir.bind(vj, j * 16 + y)
+                A[vi, vj] = 1.0
+        for x, y in tir.grid(16, 16):
+            with tir.block([128, 128], "B") as [vi, vj]:
+                tir.bind(vi, i * 16 + x)
+                tir.bind(vj, j * 16 + y)
+                B[vi, vj] = A[vi, vj] + 1.0
+
+    with tir.block([128, 128], "C") as [vi, vj]:  # reads A outside of "scope"
+        C[vi, vj] = A[vi, vj] * 2.0
+
+
+@tvm.script.tir
+def opaque_access(a: ty.handle, b: ty.handle, c: ty.handle, d: ty.handle) -> None:
+    A = tir.match_buffer(a, (128, 128), dtype="float16")  # read by all three blocks
+    B = tir.match_buffer(b, (128, 128), dtype="float16")
+    C = tir.match_buffer(c, (128, 128), dtype="float16")
+    D = tir.match_buffer(d, (128, 128), dtype="float16")
+
+    with tir.block([128, 128], "load_store") as [vi, vj]:  # raw tir.load / data store
+        tir.reads(A[vi, vj])
+        tir.writes(D[vi, vj])
+        D.data[vi * 128 + vj] = tir.load("float16", A.data, vi * 128 + vj)
+    with tir.block([8, 8], "opaque") as [vi, vj]:  # A.data passed to tvm_access_ptr
+        tir.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
+        tir.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
+        tir.evaluate(
+            tir.tvm_load_matrix_sync(
+                B.data,
+                16,
+                16,
+                16,
+                vi * 8 + vj,
+                tir.tvm_access_ptr(
+                    tir.type_annotation(dtype="float16"),
+                    A.data,
+                    vi * 2048 + vj * 16,
+                    128,
+                    1,
+                    dtype="handle",
+                ),
+                128,
+                "row_major",
+                dtype="handle",
+            )
+        )
+    with tir.block([8, 8], "match_buffer") as [vi, vj]:  # A and C accessed via match_buffer
+        tir.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
+        tir.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
+        A0 = tir.match_buffer(
+            A[
+                vi * 16 : vi * 16 + 16,
+                vj * 16 : vj * 16 + 16,
+            ],
+            (16, 16),
+            "float16",
+            strides=[128, 1],
+            offset_factor=1,
+        )
+        C0 = tir.match_buffer(
+            C[
+                vi * 16 : vi * 16 + 16,
+                vj * 16 : vj * 16 + 16,
+            ],
+            (16, 16),
+            "float16",
+            strides=[128, 1],
+            offset_factor=1,
+        )
+        tir.evaluate(
+            tir.tvm_load_matrix_sync(
+                C0.data,
+                16,
+                16,
+                16,
+                vi * 8 + vj,
+                tir.tvm_access_ptr(
+                    tir.type_annotation(dtype="float16"),
+                    A0.data,
+                    A0.elem_offset,
+                    A0.strides[0],
+                    1,
+                    dtype="handle",
+                ),
+                128,
+                "row_major",
+                dtype="handle",
+            )
+        )
+
+
+@tvm.script.tir
+def func_multi_consumer() -> None:
+    A = tir.alloc_buffer((128))  # consumed by both "B" and "C"
+    B = tir.alloc_buffer((128))
+    C = tir.alloc_buffer((128))
+    for i in tir.grid(8):
+        for j in tir.grid(16):
+            with tir.block([128], "A") as [vi]:
+                tir.bind(vi, i * 16 + j)
+                A[vi] = 1.0
+        for j in tir.grid(16):
+            with tir.block([128], "B") as [vi]:
+                tir.bind(vi, i * 16 + j)
+                B[vi] = A[vi] + 1.0
+    for i in tir.grid(128):
+        with tir.block([128], "C") as [vi]:
+            C[vi] = A[vi]
+
+
+@tvm.script.tir
+def func_multi_producer() -> None:
+    A = tir.alloc_buffer((128))  # written by two blocks: "A0" and "A1"
+    B = tir.alloc_buffer((128))
+    with tir.block([128], "A0") as [vi]:
+        A[vi] = 1.0
+    with tir.block([128], "A1") as [vi]:
+        A[vi] = 2.0
+    with tir.block([128], "B") as [vi]:
+        B[vi] = A[vi]
+
+
+########## Expected function after cache_read ##########
+
+
+@tvm.script.tir
+def cache_read_elementwise(a: ty.handle, c: ty.handle) -> None:  # A cached "global", B "local"
+    A = tir.match_buffer(a, (128, 128))
+    C = tir.match_buffer(c, (128, 128))
+    B = tir.alloc_buffer((128, 128))
+    A_global = tir.alloc_buffer((128, 128))
+    B_local = tir.alloc_buffer((128, 128), scope="local")
+    with tir.block([128, 128], "A_global") as [vi, vj]:
+        A_global[vi, vj] = A[vi, vj]
+    with tir.block([128, 128], "B") as [vi, vj]:
+        B[vi, vj] = A_global[vi, vj] * 2.0
+    with tir.block([128, 128], "B_local") as [vi, vj]:
+        B_local[vi, vj] = B[vi, vj]
+    with tir.block([128, 128], "C") as [vi, vj]:
+        C[vi, vj] = B_local[vi, vj] + 1.0
+
+
+@tvm.script.tir
+def cache_read_under_scope(b: ty.handle, c: ty.handle) -> None:
+    A = tir.alloc_buffer((128, 128))
+    B = tir.match_buffer(b, (128, 128))
+    C = tir.match_buffer(c, (128, 128))
+    A_global = tir.alloc_buffer((128, 128))  # cache of A read by "C" at the root
+
+    with tir.block([8, 8], "scope") as [i, j]:
+        A_local = tir.alloc_buffer((128, 128), scope="local")  # cache of A read by "B"
+        for x, y in tir.grid(16, 16):
+            with tir.block([128, 128], "A") as [vi, vj]:
+                tir.bind(vi, i * 16 + x)
+                tir.bind(vj, j * 16 + y)
+                A[vi, vj] = 1.0
+        for x, y in tir.grid(16, 16):
+            with tir.block([128, 128], "A_local") as [vi, vj]:
+                tir.bind(vi, i * 16 + x)
+                tir.bind(vj, j * 16 + y)
+                A_local[vi, vj] = A[vi, vj]
+        for x, y in tir.grid(16, 16):
+            with tir.block([128, 128], "B") as [vi, vj]:
+                tir.bind(vi, i * 16 + x)
+                tir.bind(vj, j * 16 + y)
+                B[vi, vj] = A_local[vi, vj] + 1.0
+    with tir.block([128, 128], "A_global") as [vi, vj]:
+        A_global[vi, vj] = A[vi, vj]
+    with tir.block([128, 128], "C") as [vi, vj]:
+        C[vi, vj] = A_global[vi, vj] * 2.0
+
+
+@tvm.script.tir
+def cache_read_opaque_access(a: ty.handle, b: ty.handle, c: ty.handle, d: ty.handle) -> None:
+    A = tir.match_buffer(a, (128, 128), dtype="float16")
+    B = tir.match_buffer(b, (128, 128), dtype="float16")
+    C = tir.match_buffer(c, (128, 128), dtype="float16")
+    D = tir.match_buffer(d, (128, 128), dtype="float16")
+    A_global = tir.alloc_buffer((128, 128), dtype="float16")  # replaces A in every reader
+
+    with tir.block([128, 128], "A_global") as [vi, vj]:
+        A_global[vi, vj] = A[vi, vj]
+    with tir.block([128, 128], "load_store") as [vi, vj]:
+        tir.reads(A_global[vi, vj])
+        tir.writes(D[vi, vj])
+        D.data[vi * 128 + vj] = tir.load("float16", A_global.data, vi * 128 + vj)
+    with tir.block([8, 8], "opaque") as [vi, vj]:
+        tir.reads(A_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
+        tir.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
+        tir.evaluate(
+            tir.tvm_load_matrix_sync(
+                B.data,
+                16,
+                16,
+                16,
+                vi * 8 + vj,
+                tir.tvm_access_ptr(
+                    tir.type_annotation(dtype="float16"),
+                    A_global.data,
+                    vi * 2048 + vj * 16,
+                    128,
+                    1,
+                    dtype="handle",
+                ),
+                128,
+                "row_major",
+                dtype="handle",
+            )
+        )
+    with tir.block([8, 8], "match_buffer") as [vi, vj]:
+        tir.reads(A_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
+        tir.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
+        A0 = tir.match_buffer(
+            A_global[
+                vi * 16 : vi * 16 + 16,
+                vj * 16 : vj * 16 + 16,
+            ],
+            (16, 16),
+            "float16",
+            strides=[128, 1],
+            offset_factor=1,
+        )
+        C0 = tir.match_buffer(
+            C[
+                vi * 16 : vi * 16 + 16,
+                vj * 16 : vj * 16 + 16,
+            ],
+            (16, 16),
+            "float16",
+            strides=[128, 1],
+            offset_factor=1,
+        )
+        tir.evaluate(
+            tir.tvm_load_matrix_sync(
+                C0.data,
+                16,
+                16,
+                16,
+                vi * 8 + vj,
+                tir.tvm_access_ptr(
+                    tir.type_annotation(dtype="float16"),
+                    A0.data,
+                    A0.elem_offset,
+                    A0.strides[0],
+                    1,
+                    dtype="handle",
+                ),
+                128,
+                "row_major",
+                dtype="handle",
+            )
+        )
+
+
+@tvm.script.tir
+def cache_read_multi_consumer() -> None:
+    A = tir.alloc_buffer((128))
+    B = tir.alloc_buffer((128))
+    C = tir.alloc_buffer((128))
+    A_global = tir.alloc_buffer((128))  # cache of A, read by both "B" and "C"
+    for i in tir.grid(8):
+        for j in tir.grid(16):
+            with tir.block([128], "A") as [vi]:
+                tir.bind(vi, i * 16 + j)
+                A[vi] = 1.0
+        for j in tir.grid(16):
+            with tir.block([128], "A_global") as [vi]:  # cache stage, right after "A"
+                tir.bind(vi, i * 16 + j)
+                A_global[vi] = A[vi]
+        for j in tir.grid(16):
+            with tir.block([128], "B") as [vi]:
+                tir.bind(vi, i * 16 + j)
+                B[vi] = A_global[vi] + 1.0
+
+    for i in tir.grid(128):
+        with tir.block([128], "C") as [vi]:
+            C[vi] = A_global[vi]
+
+
+@tvm.script.tir
+def continuous_cache_read(a: ty.handle, c: ty.handle) -> None:  # B -> B_shared -> B_local -> C
+    A = tir.match_buffer(a, (128, 128))
+    C = tir.match_buffer(c, (128, 128))
+    B = tir.alloc_buffer((128, 128))
+    B_shared = tir.alloc_buffer((128, 128), scope="shared")
+    B_local = tir.alloc_buffer((128, 128), scope="local")
+    with tir.block([128, 128], "B") as [vi, vj]:
+        B[vi, vj] = A[vi, vj] * 2.0
+    with tir.block([128, 128], "B_shared") as [vi, vj]:
+        B_shared[vi, vj] = B[vi, vj]
+    with tir.block([128, 128], "B_local") as [vi, vj]:
+        B_local[vi, vj] = B_shared[vi, vj]
+    with tir.block([128, 128], "C") as [vi, vj]:
+        C[vi, vj] = B_local[vi, vj] + 1.0
+
+
+########## Expected function after cache_write ##########
+
+
+@tvm.script.tir
+def cache_write_elementwise(a: ty.handle, c: ty.handle) -> None:  # B cached "local", C "global"
+    A = tir.match_buffer(a, (128, 128))
+    C = tir.match_buffer(c, (128, 128))
+    B = tir.alloc_buffer((128, 128))
+    B_local = tir.alloc_buffer((128, 128), scope="local")
+    C_global = tir.alloc_buffer((128, 128))
+    with tir.block([128, 128], "B_local") as [vi, vj]:
+        B_local[vi, vj] = A[vi, vj] * 2.0
+    with tir.block([128, 128], "B") as [vi, vj]:
+        B[vi, vj] = B_local[vi, vj]
+    with tir.block([128, 128], "C_global") as [vi, vj]:
+        C_global[vi, vj] = B[vi, vj] + 1.0
+    with tir.block([128, 128], "C") as [vi, vj]:
+        C[vi, vj] = C_global[vi, vj]
+
+
+@tvm.script.tir
+def cache_write_under_scope(b: ty.handle, c: ty.handle) -> None:
+    A = tir.alloc_buffer((128, 128))
+    B = tir.match_buffer(b, (128, 128))
+    C = tir.match_buffer(c, (128, 128))
+    A_global = tir.alloc_buffer((128, 128))  # written inside "scope", copied to A at the root
+
+    with tir.block([8, 8], "scope") as [i, j]:
+        A_local = tir.alloc_buffer((128, 128), scope="local")
+        B_global = tir.alloc_buffer((128, 128))
+        for x, y in tir.grid(16, 16):
+            with tir.block([128, 128], "A_local") as [vi, vj]:
+                tir.bind(vi, i * 16 + x)
+                tir.bind(vj, j * 16 + y)
+                A_local[vi, vj] = 1.0
+        for x, y in tir.grid(16, 16):
+            with tir.block([128, 128], "A") as [vi, vj]:
+                tir.bind(vi, i * 16 + x)
+                tir.bind(vj, j * 16 + y)
+                A_global[vi, vj] = A_local[vi, vj]
+        for x, y in tir.grid(16, 16):
+            with tir.block([128, 128], "B_global") as [vi, vj]:
+                tir.bind(vi, i * 16 + x)
+                tir.bind(vj, j * 16 + y)
+                B_global[vi, vj] = A_global[vi, vj] + 1.0
+        for x, y in tir.grid(16, 16):
+            with tir.block([128, 128], "B") as [vi, vj]:
+                tir.bind(vi, i * 16 + x)
+                tir.bind(vj, j * 16 + y)
+                B[vi, vj] = B_global[vi, vj]
+    with tir.block([128, 128], "A_global") as [vi, vj]:
+        A[vi, vj] = A_global[vi, vj]
+    with tir.block([128, 128], "C") as [vi, vj]:
+        C[vi, vj] = A[vi, vj] * 2.0
+
+
+@tvm.script.tir
+def cache_write_opaque_access(a: ty.handle, b: ty.handle, c: ty.handle, d: ty.handle) -> None:
+    A = tir.match_buffer(a, (128, 128), dtype="float16")
+    B = tir.match_buffer(b, (128, 128), dtype="float16")
+    C = tir.match_buffer(c, (128, 128), dtype="float16")
+    D = tir.match_buffer(d, (128, 128), dtype="float16")
+    D_global = tir.alloc_buffer((128, 128), dtype="float16")  # caches for D, B and C;
+    B_global = tir.alloc_buffer((128, 128), dtype="float16")  # each is copied back by
+    C_global = tir.alloc_buffer((128, 128), dtype="float16")  # a block at the end
+
+    with tir.block([128, 128], "load_store") as [vi, vj]:
+        tir.reads(A[vi, vj])
+        tir.writes(D_global[vi, vj])
+        D_global.data[vi * 128 + vj] = tir.load("float16", A.data, vi * 128 + vj)
+    with tir.block([8, 8], "opaque") as [vi, vj]:
+        tir.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
+        tir.writes(B_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
+        tir.evaluate(
+            tir.tvm_load_matrix_sync(
+                B_global.data,
+                16,
+                16,
+                16,
+                vi * 8 + vj,
+                tir.tvm_access_ptr(
+                    tir.type_annotation(dtype="float16"),
+                    A.data,
+                    vi * 2048 + vj * 16,
+                    128,
+                    1,
+                    dtype="handle",
+                ),
+                128,
+                "row_major",
+                dtype="handle",
+            )
+        )
+    with tir.block([8, 8], "match_buffer") as [vi, vj]:
+        tir.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
+        tir.writes(C_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
+        A0 = tir.match_buffer(
+            A[
+                vi * 16 : vi * 16 + 16,
+                vj * 16 : vj * 16 + 16,
+            ],
+            (16, 16),
+            "float16",
+            strides=[128, 1],
+            offset_factor=1,
+        )
+        C0 = tir.match_buffer(
+            C_global[
+                vi * 16 : vi * 16 + 16,
+                vj * 16 : vj * 16 + 16,
+            ],
+            (16, 16),
+            "float16",
+            strides=[128, 1],
+            offset_factor=1,
+        )
+        tir.evaluate(
+            tir.tvm_load_matrix_sync(
+                C0.data,
+                16,
+                16,
+                16,
+                vi * 8 + vj,
+                tir.tvm_access_ptr(
+                    tir.type_annotation(dtype="float16"),
+                    A0.data,
+                    A0.elem_offset,
+                    A0.strides[0],
+                    1,
+                    dtype="handle",
+                ),
+                128,
+                "row_major",
+                dtype="handle",
+            )
+        )
+
+    with tir.block([128, 128], "D") as [vi, vj]:
+        D[vi, vj] = D_global[vi, vj]
+    with tir.block([128, 128], "B") as [vi, vj]:
+        B[vi, vj] = B_global[vi, vj]
+    with tir.block([128, 128], "C") as [vi, vj]:
+        C[vi, vj] = C_global[vi, vj]
+
+
+@tvm.script.tir
+def cache_write_multi_consumer() -> None:
+    A = tir.alloc_buffer((128))
+    B = tir.alloc_buffer((128))
+    C = tir.alloc_buffer((128))
+    A_global = tir.alloc_buffer((128))  # cache written first, copied into A in the same loop
+    for i in tir.grid(8):
+        for j in tir.grid(16):
+            with tir.block([128], "A_global") as [vi]:
+                tir.bind(vi, i * 16 + j)
+                A_global[vi] = 1.0
+        for j in tir.grid(16):
+            with tir.block([128], "A") as [vi]:
+                tir.bind(vi, i * 16 + j)
+                A[vi] = A_global[vi]
+        for j in tir.grid(16):
+            with tir.block([128], "B") as [vi]:
+                tir.bind(vi, i * 16 + j)
+                B[vi] = A[vi] + 1.0
+
+    for i in tir.grid(128):
+        with tir.block([128], "C") as [vi]:
+            C[vi] = A[vi]
+
+
+@tvm.script.tir
+def continuous_cache_write(a: ty.handle, c: ty.handle) -> None:  # B_local -> B_shared -> B
+    A = tir.match_buffer(a, (128, 128))
+    B = tir.alloc_buffer((128, 128))
+    C = tir.match_buffer(c, (128, 128))
+    B_shared = tir.alloc_buffer((128, 128), scope="shared")
+    B_local = tir.alloc_buffer((128, 128), scope="local")
+    with tir.block([128, 128], "B") as [vi, vj]:
+        B_local[vi, vj] = A[vi, vj] * 2.0
+    with tir.block([128, 128], "B") as [vi, vj]:
+        B_shared[vi, vj] = B_local[vi, vj]
+    with tir.block([128, 128], "B") as [vi, vj]:
+        B[vi, vj] = B_shared[vi, vj]
+    with tir.block([128, 128], "C") as [vi, vj]:
+        C[vi, vj] = B[vi, vj] + 1.0
+
+
+########## Testcases for cache_read ##########
+
+
+def test_cache_read_elementwise():  # cache A for "B" (global) and B for "C" (local)
+    sch = tir.Schedule(elementwise, debug_mask="all")
+    block_b = sch.get_block("B")
+    block_c = sch.get_block("C")
+    cached_a = sch.cache_read(block_b, 0, "global")
+    cached_b = sch.cache_read(block_c, 0, "local")
+    assert sch.get(cached_a) == sch.get(sch.get_block("A_global"))  # returned rv = new block
+    assert sch.get(cached_b) == sch.get(sch.get_block("B_local"))
+    assert sch.get(block_b) == sch.get(sch.get_block("B"))  # original rvs stay valid
+    assert sch.get(block_c) == sch.get(sch.get_block("C"))
+    tvm.ir.assert_structural_equal(cache_read_elementwise, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=elementwise)
+
+
+def test_cache_read_under_scope():  # "B" is nested in "scope"; "C" sits at the root
+    sch = tir.Schedule(access_under_scope, debug_mask="all")
+    block_b = sch.get_block("B")
+    block_c = sch.get_block("C")
+    sch.cache_read(block_b, 0, "local")
+    sch.cache_read(block_c, 0, "global")
+    tvm.ir.assert_structural_equal(cache_read_under_scope, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=access_under_scope)
+
+
+def test_cache_read_opaque_access():  # one cache of A must replace it in all readers
+    sch = tir.Schedule(opaque_access, debug_mask="all")
+    block = sch.get_block("load_store")
+    sch.cache_read(block, 0, "global")
+    tvm.ir.assert_structural_equal(cache_read_opaque_access, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=opaque_access)
+
+
+def test_cache_read_location():  # where the cache stage lands when A has two consumers
+    sch = tir.Schedule(func_multi_consumer, debug_mask="all")
+    block_b = sch.get_block("B")
+    sch.cache_read(block_b, 0, "global")
+    tvm.ir.assert_structural_equal(cache_read_multi_consumer, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=func_multi_consumer)
+
+
+def test_continuous_cache_read():  # two cache_reads on the same consumer chain the caches
+    sch = tir.Schedule(elementwise, debug_mask="all")
+    block_c = sch.get_block("C")
+    sch.cache_read(block_c, 0, "shared")
+    sch.cache_read(block_c, 0, "local")
+    tvm.ir.assert_structural_equal(continuous_cache_read, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=elementwise)
+
+
+def test_cache_read_fail_multi_producer():  # A has producers "A0" and "A1": expect an error
+    sch = tir.Schedule(func_multi_producer, debug_mask="all")
+    block_b = sch.get_block("B")
+    with pytest.raises(tvm.tir.ScheduleError):
+        sch.cache_read(block_b, 0, "global")
+
+
+def test_cache_read_fail_index_out_of_bound():  # "B" reads only A, so index 1 must fail
+    sch = tir.Schedule(elementwise, debug_mask="all")
+    block_b = sch.get_block("B")
+    with pytest.raises(tvm.tir.ScheduleError):
+        sch.cache_read(block_b, 1, "global")
+
+
+########## Testcases for cache_write ##########
+
+
+def test_cache_write_elementwise():  # cache B's output (local) and C's output (global)
+    sch = tir.Schedule(elementwise, debug_mask="all")
+    block_b = sch.get_block("B")
+    block_c = sch.get_block("C")
+    cached_b = sch.cache_write(block_b, 0, "local")
+    cached_c = sch.cache_write(block_c, 0, "global")
+    assert sch.get(cached_b) == sch.get(sch.get_block("B_local"))  # returned rv = new block
+    assert sch.get(cached_c) == sch.get(sch.get_block("C_global"))
+    assert sch.get(block_b) == sch.get(sch.get_block("B"))  # original rvs stay valid
+    assert sch.get(block_c) == sch.get(sch.get_block("C"))
+    tvm.ir.assert_structural_equal(cache_write_elementwise, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=elementwise)
+
+
+def test_cache_write_under_scope():  # cache_write on nested blocks and on "scope" itself
+    sch = tir.Schedule(access_under_scope, debug_mask="all")
+    block_a = sch.get_block("A")
+    block_b = sch.get_block("B")
+    block_scope = sch.get_block("scope")
+    sch.cache_write(block_a, 0, "local")
+    sch.cache_write(block_b, 0, "global")
+    sch.cache_write(block_scope, 0, "global")
+    tvm.ir.assert_structural_equal(cache_write_under_scope, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=access_under_scope)
+
+
+def test_cache_write_opaque_access():  # cache the output of each opaque-style writer
+    sch = tir.Schedule(opaque_access, debug_mask="all")
+    block_store = sch.get_block("load_store")
+    block_opaque = sch.get_block("opaque")
+    block_match_buffer = sch.get_block("match_buffer")
+    sch.cache_write(block_store, 0, "global")
+    sch.cache_write(block_opaque, 0, "global")
+    sch.cache_write(block_match_buffer, 0, "global")
+    tvm.ir.assert_structural_equal(cache_write_opaque_access, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=opaque_access)
+
+
+def test_cache_write_location():  # where the write-back lands when A has two consumers
+    sch = tir.Schedule(func_multi_consumer, debug_mask="all")
+    block_a = sch.get_block("A")
+    sch.cache_write(block_a, 0, "global")
+    tvm.ir.assert_structural_equal(cache_write_multi_consumer, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=func_multi_consumer)
+
+
+def test_continuous_cache_write():  # two cache_writes on the same producer chain the caches
+    sch = tir.Schedule(elementwise, debug_mask="all")
+    block_b = sch.get_block("B")
+    sch.cache_write(block_b, 0, "shared")
+    sch.cache_write(block_b, 0, "local")
+    tvm.ir.assert_structural_equal(continuous_cache_write, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=elementwise)
+
+
+def test_cache_write_fail_multi_producer():  # A is written by "A0" and "A1": both must fail
+    sch = tir.Schedule(func_multi_producer, debug_mask="all")
+    block_a0 = sch.get_block("A0")
+    block_a1 = sch.get_block("A1")
+    with pytest.raises(tvm.tir.ScheduleError):
+        sch.cache_write(block_a0, 0, "global")
+    with pytest.raises(tvm.tir.ScheduleError):
+        sch.cache_write(block_a1, 0, "global")
+
+
+def test_cache_write_fail_index_out_of_bound():  # "B" writes only B, so index 1 must fail
+    sch = tir.Schedule(elementwise, debug_mask="all")
+    block_b = sch.get_block("B")
+    with pytest.raises(tvm.tir.ScheduleError):
+        sch.cache_write(block_b, 1, "global")
+
+
+if __name__ == "__main__":  # run this file's tests via pytest, forwarding CLI args
+    sys.exit(pytest.main([__file__] + sys.argv[1:]))

Reply via email to