This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3af50e0fce [TIR][Transform] Keep the allocate buffers order after
update buffer allocation location (#13560)
3af50e0fce is described below
commit 3af50e0fcea9a2da327343fe121498353ad912ce
Author: Fred.Jia <[email protected]>
AuthorDate: Fri Dec 9 08:38:34 2022 +0800
[TIR][Transform] Keep the allocate buffers order after update buffer
allocation location (#13560)
[TIR][Transform] Keep the allocate buffers order after update buffer
allocated location
---
.../plan_update_buffer_allocation_location.cc | 63 ++++++++++++++++++----
...sform_plan_update_buffer_allocation_location.py | 25 +++++++--
2 files changed, 74 insertions(+), 14 deletions(-)
diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc
index 90150ebd3c..4c63d3393f 100644
--- a/src/tir/transforms/plan_update_buffer_allocation_location.cc
+++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc
@@ -48,10 +48,53 @@ class CollectUnmanagedAllocations : public StmtExprVisitor {
std::unordered_set<const VarNode*> unmanaged_allocations;
};
+/*! \brief Collect the allocate buffer order. */
+class BufferAllocateOrderCollector : public StmtExprVisitor {
+ public:
+ static Array<Buffer> Collect(const PrimFunc& func) {
+ BufferAllocateOrderCollector collector;
+ for (const auto& kv : func->buffer_map) {
+ collector.buffer_alloc_recorder_.push_back(kv.second);
+ }
+ collector(func->body);
+ return std::move(collector.buffer_alloc_recorder_);
+ }
+
+ private:
+ void VisitStmt_(const BlockNode* op) final {
+ for (const Buffer& buffer : op->alloc_buffers) {
+ buffer_alloc_recorder_.push_back(buffer);
+ }
+ StmtExprVisitor::VisitStmt_(op);
+ }
+
+ void VisitExpr_(const BufferLoadNode* op) final {
+ if (std::find(buffer_alloc_recorder_.begin(), buffer_alloc_recorder_.end(), op->buffer) ==
+ buffer_alloc_recorder_.end()) {
+ buffer_alloc_recorder_.push_back(op->buffer);
+ }
+ StmtExprVisitor::VisitExpr_(op);
+ }
+
+ void VisitStmt_(const BufferStoreNode* op) final {
+ if (std::find(buffer_alloc_recorder_.begin(), buffer_alloc_recorder_.end(), op->buffer) ==
+ buffer_alloc_recorder_.end()) {
+ buffer_alloc_recorder_.push_back(op->buffer);
+ }
+ StmtExprVisitor::VisitStmt_(op);
+ }
+
+ /*! \brief The buffer allocated order recorder. */
+ Array<Buffer> buffer_alloc_recorder_;
+};
+
class BufferAllocationLocator : public StmtExprMutator {
public:
explicit BufferAllocationLocator(const PrimFunc& func) {
Map<Buffer, Optional<Stmt>> buffer_lca = DetectBufferAccessLCA(func);
+ // The buffer_alloc_recorder Array is used to keep the buffer allocation order
+ // since the buffer_lca Map is unordered.
+ Array<Buffer> buffer_alloc_recorder = BufferAllocateOrderCollector::Collect(func);
std::unordered_set<const VarNode*> arg_buffer_vars;
CollectUnmanagedAllocations collector;
collector(func->body);
@@ -63,16 +106,18 @@ class BufferAllocationLocator : public StmtExprMutator {
buffer_data_to_buffer_.Set(buffer->data, buffer);
}
// create buffers to be allocated at each stmts
- for (const auto& kv : buffer_lca) {
- const Buffer& buffer = kv.first;
- const StmtNode* stmt = kv.second.get();
- if (arg_buffer_vars.count(buffer->data.get())) {
- continue;
+ for (const auto& buffer : buffer_alloc_recorder) {
+ auto it = buffer_lca.find(buffer);
+ if (it != buffer_lca.end()) {
+ const StmtNode* stmt = (*it).second.get();
+ if (arg_buffer_vars.count(buffer->data.get())) {
+ continue;
+ }
+ if (!unmanaged_allocations_.count(buffer->data.get())) {
+ alloc_buffers_[stmt].push_back(buffer);
+ }
+ buffer_data_to_buffer_.Set(buffer->data, buffer);
}
- if (!unmanaged_allocations_.count(buffer->data.get())) {
- alloc_buffers_[stmt].push_back(buffer);
- }
- buffer_data_to_buffer_.Set(buffer->data, buffer);
}
}
diff --git a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py
index 34d82f86a4..92e3cbd66e 100644
--- a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py
+++ b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py
@@ -245,11 +245,13 @@ def test_lower_te():
def test_loop_carried_dependency():
"""The buffer allocation should be above opaque iter var's loop scopes
- such that buffer accesses with loop carried dependencies are covered."""
+ such that buffer accesses with loop carried dependencies are covered,
+ and the allocate buffer should keep the order."""
@T.prim_func
def before(A: T.Buffer[(8, 8, 8), "int32"], B: T.Buffer[(8, 8, 8), "int32"]):
C = T.alloc_buffer([8, 8, 8], dtype="int32")
+ D = T.alloc_buffer([8, 8, 8], dtype="int32")
for i in T.serial(8):
for j in T.serial(8):
for k in T.serial(8):
@@ -258,10 +260,16 @@ def test_loop_carried_dependency():
C[vi, vj, vk] = A[vi, vj, vk] + 1
for k in T.serial(8):
with T.block("b1"):
+ vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+ D[vi, vj, vk] = A[vi, vj, vk] + 2
+ for k in T.serial(8):
+ with T.block("b2"):
vi, vk = T.axis.remap("SS", [i, k])
vj = T.axis.opaque(8, j)
- B[vi, vj, vk] = C[vi, vj, vk] + T.if_then_else(
- 0 < vj, C[vi, vj - 1, vk], 0, dtype="int32"
+ B[vi, vj, vk] = (
+ C[vi, vj, vk]
+ + T.if_then_else(0 < vj, C[vi, vj - 1, vk], 0, dtype="int32")
+ + D[vi, vj, vk]
)
@T.prim_func
@@ -271,6 +279,7 @@ def test_loop_carried_dependency():
T.reads(A[i, 0:8, 0:8])
T.writes(B[i, 0:8, 0:8])
C = T.alloc_buffer([8, 8, 8], dtype="int32")
+ D = T.alloc_buffer([8, 8, 8], dtype="int32")
for j in T.serial(8):
for k in T.serial(8):
with T.block("b0"):
@@ -278,10 +287,16 @@ def test_loop_carried_dependency():
C[vi, vj, vk] = A[vi, vj, vk] + 1
for k in T.serial(8):
with T.block("b1"):
+ vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+ D[vi, vj, vk] = A[vi, vj, vk] + 2
+ for k in T.serial(8):
+ with T.block("b2"):
vi, vk = T.axis.remap("SS", [i, k])
vj = T.axis.opaque(8, j)
- B[vi, vj, vk] = C[vi, vj, vk] + T.if_then_else(
- 0 < vj, C[vi, vj - 1, vk], 0, dtype="int32"
+ B[vi, vj, vk] = (
+ C[vi, vj, vk]
+ + T.if_then_else(0 < vj, C[vi, vj - 1, vk], 0, dtype="int32")
+ + D[vi, vj, vk]
)
_check(before, after)