This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3af50e0fce [TIR][Transform] Keep the allocate buffers order after update buffer allocation location (#13560)
3af50e0fce is described below

commit 3af50e0fcea9a2da327343fe121498353ad912ce
Author: Fred.Jia <[email protected]>
AuthorDate: Fri Dec 9 08:38:34 2022 +0800

    [TIR][Transform] Keep the allocate buffers order after update buffer allocation location (#13560)
    
    [TIR][Transform] Keep the allocate buffers order after update buffer allocated location
---
 .../plan_update_buffer_allocation_location.cc      | 63 ++++++++++++++++++----
 ...sform_plan_update_buffer_allocation_location.py | 25 +++++++--
 2 files changed, 74 insertions(+), 14 deletions(-)

diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc
index 90150ebd3c..4c63d3393f 100644
--- a/src/tir/transforms/plan_update_buffer_allocation_location.cc
+++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc
@@ -48,10 +48,53 @@ class CollectUnmanagedAllocations : public StmtExprVisitor {
   std::unordered_set<const VarNode*> unmanaged_allocations;
 };
 
+/*! \brief Collect the allocate buffer order. */
+class BufferAllocateOrderCollector : public StmtExprVisitor {
+ public:
+  static Array<Buffer> Collect(const PrimFunc& func) {
+    BufferAllocateOrderCollector collector;
+    for (const auto& kv : func->buffer_map) {
+      collector.buffer_alloc_recorder_.push_back(kv.second);
+    }
+    collector(func->body);
+    return std::move(collector.buffer_alloc_recorder_);
+  }
+
+ private:
+  void VisitStmt_(const BlockNode* op) final {
+    for (const Buffer& buffer : op->alloc_buffers) {
+      buffer_alloc_recorder_.push_back(buffer);
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const BufferLoadNode* op) final {
+    if (std::find(buffer_alloc_recorder_.begin(), buffer_alloc_recorder_.end(), op->buffer) ==
+        buffer_alloc_recorder_.end()) {
+      buffer_alloc_recorder_.push_back(op->buffer);
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const BufferStoreNode* op) final {
+    if (std::find(buffer_alloc_recorder_.begin(), buffer_alloc_recorder_.end(), op->buffer) ==
+        buffer_alloc_recorder_.end()) {
+      buffer_alloc_recorder_.push_back(op->buffer);
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  /*! \brief The buffer allocated order recorder. */
+  Array<Buffer> buffer_alloc_recorder_;
+};
+
 class BufferAllocationLocator : public StmtExprMutator {
  public:
   explicit BufferAllocationLocator(const PrimFunc& func) {
     Map<Buffer, Optional<Stmt>> buffer_lca = DetectBufferAccessLCA(func);
+    // The buffer_alloc_recorder Array is used to keep the buffer allocation order
+    // since the buffer_lca Map is unordered.
+    Array<Buffer> buffer_alloc_recorder = BufferAllocateOrderCollector::Collect(func);
     std::unordered_set<const VarNode*> arg_buffer_vars;
     CollectUnmanagedAllocations collector;
     collector(func->body);
@@ -63,16 +106,18 @@ class BufferAllocationLocator : public StmtExprMutator {
       buffer_data_to_buffer_.Set(buffer->data, buffer);
     }
     // create buffers to be allocated at each stmts
-    for (const auto& kv : buffer_lca) {
-      const Buffer& buffer = kv.first;
-      const StmtNode* stmt = kv.second.get();
-      if (arg_buffer_vars.count(buffer->data.get())) {
-        continue;
+    for (const auto& buffer : buffer_alloc_recorder) {
+      auto it = buffer_lca.find(buffer);
+      if (it != buffer_lca.end()) {
+        const StmtNode* stmt = (*it).second.get();
+        if (arg_buffer_vars.count(buffer->data.get())) {
+          continue;
+        }
+        if (!unmanaged_allocations_.count(buffer->data.get())) {
+          alloc_buffers_[stmt].push_back(buffer);
+        }
+        buffer_data_to_buffer_.Set(buffer->data, buffer);
       }
-      if (!unmanaged_allocations_.count(buffer->data.get())) {
-        alloc_buffers_[stmt].push_back(buffer);
-      }
-      buffer_data_to_buffer_.Set(buffer->data, buffer);
     }
   }
 
diff --git a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py
index 34d82f86a4..92e3cbd66e 100644
--- a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py
+++ b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py
@@ -245,11 +245,13 @@ def test_lower_te():
 
 def test_loop_carried_dependency():
     """The buffer allocation should be above opaque iter var's loop scopes
-    such that buffer accesses with loop carried dependencies are covered."""
+    such that buffer accesses with loop carried dependencies are covered,
+    and the allocate buffer should keep the order."""
 
     @T.prim_func
     def before(A: T.Buffer[(8, 8, 8), "int32"], B: T.Buffer[(8, 8, 8), "int32"]):
         C = T.alloc_buffer([8, 8, 8], dtype="int32")
+        D = T.alloc_buffer([8, 8, 8], dtype="int32")
         for i in T.serial(8):
             for j in T.serial(8):
                 for k in T.serial(8):
@@ -258,10 +260,16 @@ def test_loop_carried_dependency():
                         C[vi, vj, vk] = A[vi, vj, vk] + 1
                 for k in T.serial(8):
                     with T.block("b1"):
+                        vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+                        D[vi, vj, vk] = A[vi, vj, vk] + 2
+                for k in T.serial(8):
+                    with T.block("b2"):
                         vi, vk = T.axis.remap("SS", [i, k])
                         vj = T.axis.opaque(8, j)
-                        B[vi, vj, vk] = C[vi, vj, vk] + T.if_then_else(
-                            0 < vj, C[vi, vj - 1, vk], 0, dtype="int32"
+                        B[vi, vj, vk] = (
+                            C[vi, vj, vk]
+                            + T.if_then_else(0 < vj, C[vi, vj - 1, vk], 0, dtype="int32")
+                            + D[vi, vj, vk]
                         )
 
     @T.prim_func
@@ -271,6 +279,7 @@ def test_loop_carried_dependency():
                 T.reads(A[i, 0:8, 0:8])
                 T.writes(B[i, 0:8, 0:8])
                 C = T.alloc_buffer([8, 8, 8], dtype="int32")
+                D = T.alloc_buffer([8, 8, 8], dtype="int32")
                 for j in T.serial(8):
                     for k in T.serial(8):
                         with T.block("b0"):
@@ -278,10 +287,16 @@ def test_loop_carried_dependency():
                             C[vi, vj, vk] = A[vi, vj, vk] + 1
                     for k in T.serial(8):
                         with T.block("b1"):
+                            vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+                            D[vi, vj, vk] = A[vi, vj, vk] + 2
+                    for k in T.serial(8):
+                        with T.block("b2"):
                             vi, vk = T.axis.remap("SS", [i, k])
                             vj = T.axis.opaque(8, j)
-                            B[vi, vj, vk] = C[vi, vj, vk] + T.if_then_else(
-                                0 < vj, C[vi, vj - 1, vk], 0, dtype="int32"
+                            B[vi, vj, vk] = (
+                                C[vi, vj, vk]
+                                + T.if_then_else(0 < vj, C[vi, vj - 1, vk], 0, dtype="int32")
+                                + D[vi, vj, vk]
                             )
 
     _check(before, after)

Reply via email to