This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 52c3d84040 [TIR] Handle DeclBuffer in StorageRewrite (#15051)
52c3d84040 is described below

commit 52c3d84040daa133a3038a93768ca877ee7b1340
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Jun 8 06:41:49 2023 -0400

    [TIR] Handle DeclBuffer in StorageRewrite (#15051)
    
    Allow `tir::DeclBuffer` to appear in the input of the `StorageRewrite`
    transform.  Any `DeclBuffer` whose backing allocation is rewritten is
    updated with the new buffer object.  Any `DeclBuffer` whose backing
    allocation is unused, and has therefore been removed by
    `StorageRewrite`, is itself removed.
    
    This is a subset of changes, being split out from
    https://github.com/apache/tvm/pull/14778 into independent portions.
---
 src/tir/transforms/storage_rewrite.cc              | 70 ++++++++++++++++------
 .../unittest/test_tir_transform_storage_rewrite.py | 66 ++++++++++++++++++++
 2 files changed, 117 insertions(+), 19 deletions(-)

diff --git a/src/tir/transforms/storage_rewrite.cc 
b/src/tir/transforms/storage_rewrite.cc
index 240b16aa5b..3ecd0f64bb 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -104,12 +104,14 @@ class LinearAccessPatternFinder final : public 
StmtExprVisitor {
     scope_.push_back(StmtEntry());
     // visit subexpr
     StmtExprVisitor::VisitStmt_(op);
+    all_buffers_accessed_.insert(op->buffer.get());
+
     // Add write access.
-    const VarNode* buf = op->buffer->data.get();
-    auto it = alloc_info_.find(buf);
+    const VarNode* buffer_var = op->buffer->data.get();
+    auto it = alloc_info_.find(buffer_var);
     if (it != alloc_info_.end() && it->second.alloc) {
       ICHECK_LT(it->second.level, scope_.size());
-      scope_[it->second.level].touched.push_back(buf);
+      scope_[it->second.level].touched.push_back(buffer_var);
 
       ICHECK_EQ(op->buffer->axis_separators.size() + 1, 
it->second.num_physical_dimensions)
           << "Buffer " << op->buffer->name << " is allocated with "
@@ -128,11 +130,14 @@ class LinearAccessPatternFinder final : public 
StmtExprVisitor {
   void VisitExpr_(const BufferLoadNode* op) final {
     // Add write access.
     StmtExprVisitor::VisitExpr_(op);
-    const VarNode* buf = op->buffer->data.get();
-    auto it = alloc_info_.find(buf);
+
+    all_buffers_accessed_.insert(op->buffer.get());
+
+    const VarNode* buffer_var = op->buffer->data.get();
+    auto it = alloc_info_.find(buffer_var);
     if (it != alloc_info_.end() && it->second.alloc) {
       ICHECK_LT(it->second.level, scope_.size()) << "Load memory in places 
other than store.";
-      scope_[it->second.level].touched.push_back(buf);
+      scope_[it->second.level].touched.push_back(buffer_var);
 
       ICHECK_EQ(op->buffer->axis_separators.size() + 1, 
it->second.num_physical_dimensions)
           << "Buffer " << op->buffer->name << " is allocated with "
@@ -213,6 +218,9 @@ class LinearAccessPatternFinder final : public 
StmtExprVisitor {
   std::vector<StmtEntry> linear_seq_;
   // The storage scope of each buffer
   std::unordered_map<const VarNode*, AllocEntry> alloc_info_;
+  // A record of which Buffer objects have been accessed, to prune
+  // unused DeclBuffer instances.
+  std::unordered_set<const BufferNode*> all_buffers_accessed_;
 
  private:
   // Whether already in thread env.
@@ -378,6 +386,7 @@ class StoragePlanRewriter : public StmtExprMutator {
     finder(stmt);
     this->LivenessAnalysis(finder.linear_seq_);
     this->PlanMemory(finder.linear_seq_, finder.alloc_info_);
+    all_buffers_accessed_ = finder.all_buffers_accessed_;
     this->PrepareNewAlloc();
     // start rewrite
     stmt = operator()(std::move(stmt));
@@ -505,6 +514,20 @@ class StoragePlanRewriter : public StmtExprMutator {
 
   Stmt VisitStmt_(const AllocateNode* op) final { return 
this->VisitStmt(op->body); }
 
+  Stmt VisitStmt_(const DeclBufferNode* op) final {
+    if (hoisted_buffer_decls_.count(op->buffer.get()) ||
+        !all_buffers_accessed_.count(op->buffer.get())) {
+      return this->VisitStmt(op->body);
+    }
+    auto node = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));
+
+    if (auto it = alloc_map_.find(op->buffer->data.get()); it != 
alloc_map_.end()) {
+      Buffer buf = RemapBuffer(op->buffer, it->second->alloc_var);
+      node.CopyOnWrite()->buffer = buf;
+    }
+    return std::move(node);
+  }
+
  private:
   struct StorageEntry {
     // The scope that this alloc attaches after
@@ -523,8 +546,9 @@ class StoragePlanRewriter : public StmtExprMutator {
     std::vector<const AllocateNode*> allocs;
     // The children of this entry, not including itself.
     std::vector<StorageEntry*> merged_children;
-    // The replacement allocation, if any.
-    Stmt new_alloc;
+    // The replacement Allocate, if any.  May also include associated
+    // DeclBuffer statement.
+    std::vector<Stmt> alloc_nest;
     // The var expr of new allocation.
     Var alloc_var;
     // The allocation element type.
@@ -560,13 +584,10 @@ class StoragePlanRewriter : public StmtExprMutator {
   };
 
   Stmt MakeAttach(const std::vector<StorageEntry*>& svec, Stmt body) {
-    std::vector<Stmt> nest;
-    for (StorageEntry* e : svec) {
-      if (e->new_alloc.defined()) {
-        nest.push_back(e->new_alloc);
-      }
+    for (auto it = svec.rbegin(); it != svec.rend(); it++) {
+      body = MergeNest((*it)->alloc_nest, body);
     }
-    return MergeNest(nest, body);
+    return body;
   }
   // Remap the index
   PrimExpr RemapIndex(DataType dtype, PrimExpr index, StorageEntry* e) {
@@ -636,8 +657,13 @@ class StoragePlanRewriter : public StmtExprMutator {
 
         if (all_allocs_identical) {
           // simply use the original allocation.
-          e->new_alloc = Allocate(e->alloc_var, alloc_type, 
e->allocs[0]->extents,
-                                  e->allocs[0]->condition, Evaluate(0));
+          e->alloc_nest.push_back(Allocate(e->alloc_var, alloc_type, 
e->allocs[0]->extents,
+                                           e->allocs[0]->condition, 
Evaluate(0)));
+          if (auto ptr = e->allocs[0]->body.as<DeclBufferNode>()) {
+            e->alloc_nest.push_back(
+                DeclBuffer(RemapBuffer(ptr->buffer, e->alloc_var), 
Evaluate(0)));
+            hoisted_buffer_decls_.insert(ptr->buffer.get());
+          }
           if (IsSpecialTaggedMemory(e->scope)) {
             MemoryInfo info = GetMemoryInfo(e->scope.to_string());
             if (info.defined()) {
@@ -684,8 +710,8 @@ class StoragePlanRewriter : public StmtExprMutator {
             combo_size = combo_size + make_const(DataType::Int(32), 1);
           }
           combo_size = analyzer_.Simplify(combo_size);
-          e->new_alloc =
-              Allocate(e->alloc_var, alloc_type, {combo_size}, const_true(), 
Evaluate(0));
+          e->alloc_nest.push_back(
+              Allocate(e->alloc_var, alloc_type, {combo_size}, const_true(), 
Evaluate(0)));
           if (IsSpecialTaggedMemory(e->scope)) {
             MemoryInfo info = GetMemoryInfo(e->scope.to_string());
             if (info.defined()) {
@@ -729,7 +755,8 @@ class StoragePlanRewriter : public StmtExprMutator {
     uint64_t type_bits = e->elem_type.bits() * e->elem_type.lanes();
     PrimExpr alloc_size =
         make_const(e->allocs[0]->extents[0].dtype(), (total_bits + type_bits - 
1) / type_bits);
-    e->new_alloc = Allocate(e->alloc_var, e->elem_type, {alloc_size}, 
const_true(), Evaluate(0));
+    e->alloc_nest.push_back(
+        Allocate(e->alloc_var, e->elem_type, {alloc_size}, const_true(), 
Evaluate(0)));
     if (info.defined()) {
       ICHECK_LE(total_bits, info->max_num_bits)
           << "Allocation exceed bound of memory tag " << e->scope.to_string();
@@ -996,6 +1023,11 @@ class StoragePlanRewriter : public StmtExprMutator {
   std::vector<std::unique_ptr<StorageEntry>> alloc_vec_;
   // The buffer objects being remapped
   std::unordered_map<const BufferNode*, Buffer> buffer_remap_;
+  // Buffers whose DeclBuffer has been hoisted to be adjacent to the new 
Allocate location
+  std::unordered_set<const BufferNode*> hoisted_buffer_decls_;
+  // Any buffer that is accessed at some point.  DeclBuffer instances
+  // that do not appear in this set may be removed.
+  std::unordered_set<const BufferNode*> all_buffers_accessed_;
   // analyzer
   arith::Analyzer analyzer_;
 };
diff --git a/tests/python/unittest/test_tir_transform_storage_rewrite.py 
b/tests/python/unittest/test_tir_transform_storage_rewrite.py
index cff76766b3..34de6fcabf 100644
--- a/tests/python/unittest/test_tir_transform_storage_rewrite.py
+++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py
@@ -860,5 +860,71 @@ class TestNoRewriteOfSharedNonFlatBuffer(BaseCompare):
     expected = before
 
 
+class TestRewriteDeclBuffer(BaseCompare):
+    """A DeclBuffer node may appear in StorageRewrite's input"""
+
+    def before(A: T.Buffer(16, "float32"), D: T.Buffer(16, "float32")):
+        B = T.decl_buffer(16, dtype="float32")
+        C = T.decl_buffer(16, dtype="float32")
+
+        for i in range(16):
+            B[i] = A[i]
+
+        for i in range(16):
+            C[i] = 2.0 * B[i]
+
+        for i in range(16):
+            D[i] = C[i]
+
+    def expected(A: T.Buffer(16, "float32"), D: T.Buffer(16, "float32")):
+        B = T.decl_buffer(16, dtype="float32")
+        C = T.decl_buffer(16, dtype="float32", data=B.data)
+
+        for i in range(16):
+            B[i] = A[i]
+
+        for i in range(16):
+            C[i] = 2.0 * B[i]
+
+        for i in range(16):
+            D[i] = C[i]
+
+
+class TestNoOrphanedDeclBuffer(BaseCompare):
+    """A DeclBuffer of an unused Allocate should be removed
+
+    StorageRewrite removes any allocations that are unused.  When it
+    does so, any DeclBuffer that refers to that allocation should also
+    be removed.
+    """
+
+    def before(A: T.Buffer(16, "float32"), D: T.Buffer(16, "float32")):
+        B = T.decl_buffer(16, dtype="float32")
+        C = T.decl_buffer(16, dtype="float32")
+        Unused = T.decl_buffer(16, dtype="float32")
+
+        for i in range(16):
+            B[i] = A[i]
+
+        for i in range(16):
+            C[i] = 2.0 * B[i]
+
+        for i in range(16):
+            D[i] = C[i]
+
+    def expected(A: T.Buffer(16, "float32"), D: T.Buffer(16, "float32")):
+        B = T.decl_buffer(16, dtype="float32")
+        C = T.decl_buffer(16, dtype="float32", data=B.data)
+
+        for i in range(16):
+            B[i] = A[i]
+
+        for i in range(16):
+            C[i] = 2.0 * B[i]
+
+        for i in range(16):
+            D[i] = C[i]
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to