This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 52c3d84040 [TIR] Handle DeclBuffer in StorageRewrite (#15051)
52c3d84040 is described below
commit 52c3d84040daa133a3038a93768ca877ee7b1340
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Jun 8 06:41:49 2023 -0400
[TIR] Handle DeclBuffer in StorageRewrite (#15051)
Allow `tir::DeclBuffer` to appear in the input of the `StorageRewrite`
transform. Any `DeclBuffer` whose backing allocation is rewritten is
updated with the new buffer object. Any `DeclBuffer` whose backing
allocation is unused, and has therefore been removed by `StorageRewrite`,
is itself removed.
This is a subset of changes, being split out from
https://github.com/apache/tvm/pull/14778 into independent portions.
---
src/tir/transforms/storage_rewrite.cc | 70 ++++++++++++++++------
.../unittest/test_tir_transform_storage_rewrite.py | 66 ++++++++++++++++++++
2 files changed, 117 insertions(+), 19 deletions(-)
diff --git a/src/tir/transforms/storage_rewrite.cc
b/src/tir/transforms/storage_rewrite.cc
index 240b16aa5b..3ecd0f64bb 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -104,12 +104,14 @@ class LinearAccessPatternFinder final : public
StmtExprVisitor {
scope_.push_back(StmtEntry());
// visit subexpr
StmtExprVisitor::VisitStmt_(op);
+ all_buffers_accessed_.insert(op->buffer.get());
+
// Add write access.
- const VarNode* buf = op->buffer->data.get();
- auto it = alloc_info_.find(buf);
+ const VarNode* buffer_var = op->buffer->data.get();
+ auto it = alloc_info_.find(buffer_var);
if (it != alloc_info_.end() && it->second.alloc) {
ICHECK_LT(it->second.level, scope_.size());
- scope_[it->second.level].touched.push_back(buf);
+ scope_[it->second.level].touched.push_back(buffer_var);
ICHECK_EQ(op->buffer->axis_separators.size() + 1,
it->second.num_physical_dimensions)
<< "Buffer " << op->buffer->name << " is allocated with "
@@ -128,11 +130,14 @@ class LinearAccessPatternFinder final : public
StmtExprVisitor {
void VisitExpr_(const BufferLoadNode* op) final {
// Add write access.
StmtExprVisitor::VisitExpr_(op);
- const VarNode* buf = op->buffer->data.get();
- auto it = alloc_info_.find(buf);
+
+ all_buffers_accessed_.insert(op->buffer.get());
+
+ const VarNode* buffer_var = op->buffer->data.get();
+ auto it = alloc_info_.find(buffer_var);
if (it != alloc_info_.end() && it->second.alloc) {
ICHECK_LT(it->second.level, scope_.size()) << "Load memory in places
other than store.";
- scope_[it->second.level].touched.push_back(buf);
+ scope_[it->second.level].touched.push_back(buffer_var);
ICHECK_EQ(op->buffer->axis_separators.size() + 1,
it->second.num_physical_dimensions)
<< "Buffer " << op->buffer->name << " is allocated with "
@@ -213,6 +218,9 @@ class LinearAccessPatternFinder final : public
StmtExprVisitor {
std::vector<StmtEntry> linear_seq_;
// The storage scope of each buffer
std::unordered_map<const VarNode*, AllocEntry> alloc_info_;
+ // A record of which Buffer objects have been accessed, to prune
+ // unused DeclBuffer instances.
+ std::unordered_set<const BufferNode*> all_buffers_accessed_;
private:
// Whether already in thread env.
@@ -378,6 +386,7 @@ class StoragePlanRewriter : public StmtExprMutator {
finder(stmt);
this->LivenessAnalysis(finder.linear_seq_);
this->PlanMemory(finder.linear_seq_, finder.alloc_info_);
+ all_buffers_accessed_ = finder.all_buffers_accessed_;
this->PrepareNewAlloc();
// start rewrite
stmt = operator()(std::move(stmt));
@@ -505,6 +514,20 @@ class StoragePlanRewriter : public StmtExprMutator {
Stmt VisitStmt_(const AllocateNode* op) final { return
this->VisitStmt(op->body); }
+ Stmt VisitStmt_(const DeclBufferNode* op) final {
+ if (hoisted_buffer_decls_.count(op->buffer.get()) ||
+ !all_buffers_accessed_.count(op->buffer.get())) {
+ return this->VisitStmt(op->body);
+ }
+ auto node = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));
+
+ if (auto it = alloc_map_.find(op->buffer->data.get()); it !=
alloc_map_.end()) {
+ Buffer buf = RemapBuffer(op->buffer, it->second->alloc_var);
+ node.CopyOnWrite()->buffer = buf;
+ }
+ return std::move(node);
+ }
+
private:
struct StorageEntry {
// The scope that this alloc attaches after
@@ -523,8 +546,9 @@ class StoragePlanRewriter : public StmtExprMutator {
std::vector<const AllocateNode*> allocs;
// The children of this entry, not including itself.
std::vector<StorageEntry*> merged_children;
- // The replacement allocation, if any.
- Stmt new_alloc;
+ // The replacement Allocate, if any. May also include associated
+ // DeclBuffer statement.
+ std::vector<Stmt> alloc_nest;
// The var expr of new allocation.
Var alloc_var;
// The allocation element type.
@@ -560,13 +584,10 @@ class StoragePlanRewriter : public StmtExprMutator {
};
Stmt MakeAttach(const std::vector<StorageEntry*>& svec, Stmt body) {
- std::vector<Stmt> nest;
- for (StorageEntry* e : svec) {
- if (e->new_alloc.defined()) {
- nest.push_back(e->new_alloc);
- }
+ for (auto it = svec.rbegin(); it != svec.rend(); it++) {
+ body = MergeNest((*it)->alloc_nest, body);
}
- return MergeNest(nest, body);
+ return body;
}
// Remap the index
PrimExpr RemapIndex(DataType dtype, PrimExpr index, StorageEntry* e) {
@@ -636,8 +657,13 @@ class StoragePlanRewriter : public StmtExprMutator {
if (all_allocs_identical) {
// simply use the original allocation.
- e->new_alloc = Allocate(e->alloc_var, alloc_type,
e->allocs[0]->extents,
- e->allocs[0]->condition, Evaluate(0));
+ e->alloc_nest.push_back(Allocate(e->alloc_var, alloc_type,
e->allocs[0]->extents,
+ e->allocs[0]->condition,
Evaluate(0)));
+ if (auto ptr = e->allocs[0]->body.as<DeclBufferNode>()) {
+ e->alloc_nest.push_back(
+ DeclBuffer(RemapBuffer(ptr->buffer, e->alloc_var),
Evaluate(0)));
+ hoisted_buffer_decls_.insert(ptr->buffer.get());
+ }
if (IsSpecialTaggedMemory(e->scope)) {
MemoryInfo info = GetMemoryInfo(e->scope.to_string());
if (info.defined()) {
@@ -684,8 +710,8 @@ class StoragePlanRewriter : public StmtExprMutator {
combo_size = combo_size + make_const(DataType::Int(32), 1);
}
combo_size = analyzer_.Simplify(combo_size);
- e->new_alloc =
- Allocate(e->alloc_var, alloc_type, {combo_size}, const_true(),
Evaluate(0));
+ e->alloc_nest.push_back(
+ Allocate(e->alloc_var, alloc_type, {combo_size}, const_true(),
Evaluate(0)));
if (IsSpecialTaggedMemory(e->scope)) {
MemoryInfo info = GetMemoryInfo(e->scope.to_string());
if (info.defined()) {
@@ -729,7 +755,8 @@ class StoragePlanRewriter : public StmtExprMutator {
uint64_t type_bits = e->elem_type.bits() * e->elem_type.lanes();
PrimExpr alloc_size =
make_const(e->allocs[0]->extents[0].dtype(), (total_bits + type_bits -
1) / type_bits);
- e->new_alloc = Allocate(e->alloc_var, e->elem_type, {alloc_size},
const_true(), Evaluate(0));
+ e->alloc_nest.push_back(
+ Allocate(e->alloc_var, e->elem_type, {alloc_size}, const_true(),
Evaluate(0)));
if (info.defined()) {
ICHECK_LE(total_bits, info->max_num_bits)
<< "Allocation exceed bound of memory tag " << e->scope.to_string();
@@ -996,6 +1023,11 @@ class StoragePlanRewriter : public StmtExprMutator {
std::vector<std::unique_ptr<StorageEntry>> alloc_vec_;
// The buffer objects being remapped
std::unordered_map<const BufferNode*, Buffer> buffer_remap_;
+ // Buffers whose DeclBuffer has been hoisted to be adjacent to the new
Allocate location
+ std::unordered_set<const BufferNode*> hoisted_buffer_decls_;
+ // Any buffers that is accessed at some point. DeclBuffer instances
+ // that do not appear in this list may be removed.
+ std::unordered_set<const BufferNode*> all_buffers_accessed_;
// analyzer
arith::Analyzer analyzer_;
};
diff --git a/tests/python/unittest/test_tir_transform_storage_rewrite.py
b/tests/python/unittest/test_tir_transform_storage_rewrite.py
index cff76766b3..34de6fcabf 100644
--- a/tests/python/unittest/test_tir_transform_storage_rewrite.py
+++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py
@@ -860,5 +860,71 @@ class TestNoRewriteOfSharedNonFlatBuffer(BaseCompare):
expected = before
+class TestRewriteDeclBuffer(BaseCompare):
+ """A DeclBuffer node may appear in StorageRewrite's input"""
+
+ def before(A: T.Buffer(16, "float32"), D: T.Buffer(16, "float32")):
+ B = T.decl_buffer(16, dtype="float32")
+ C = T.decl_buffer(16, dtype="float32")
+
+ for i in range(16):
+ B[i] = A[i]
+
+ for i in range(16):
+ C[i] = 2.0 * B[i]
+
+ for i in range(16):
+ D[i] = C[i]
+
+ def expected(A: T.Buffer(16, "float32"), D: T.Buffer(16, "float32")):
+ B = T.decl_buffer(16, dtype="float32")
+ C = T.decl_buffer(16, dtype="float32", data=B.data)
+
+ for i in range(16):
+ B[i] = A[i]
+
+ for i in range(16):
+ C[i] = 2.0 * B[i]
+
+ for i in range(16):
+ D[i] = C[i]
+
+
+class TestNoOrphanedDeclBuffer(BaseCompare):
+ """A DeclBuffer of an unused Allocate should be removed
+
+ StorageRewrite removes any allocations that are unused. When it
+ does so, any DeclBuffer that refers to that allocation should also
+ be removed.
+ """
+
+ def before(A: T.Buffer(16, "float32"), D: T.Buffer(16, "float32")):
+ B = T.decl_buffer(16, dtype="float32")
+ C = T.decl_buffer(16, dtype="float32")
+ Unused = T.decl_buffer(16, dtype="float32")
+
+ for i in range(16):
+ B[i] = A[i]
+
+ for i in range(16):
+ C[i] = 2.0 * B[i]
+
+ for i in range(16):
+ D[i] = C[i]
+
+ def expected(A: T.Buffer(16, "float32"), D: T.Buffer(16, "float32")):
+ B = T.decl_buffer(16, dtype="float32")
+ C = T.decl_buffer(16, dtype="float32", data=B.data)
+
+ for i in range(16):
+ B[i] = A[i]
+
+ for i in range(16):
+ C[i] = 2.0 * B[i]
+
+ for i in range(16):
+ D[i] = C[i]
+
+
if __name__ == "__main__":
tvm.testing.main()