This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 897019d [Pass][Bugfix] Disable re-use of non-flat buffers in StorageRewrite. (#10787)
897019d is described below
commit 897019df6a86720f0157a345e62b538975f11ae8
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Mar 30 14:48:19 2022 -0500
[Pass][Bugfix] Disable re-use of non-flat buffers in StorageRewrite. (#10787)
* [Pass][Bugfix] Disable re-use of non-flat buffers in StorageRewrite.
As a follow-up to https://github.com/apache/tvm/pull/9727, this
restricts StorageRewrite to modifying only flat memory buffers. When
rewriting, the existing algorithm in StorageRewrite flattens N-d
allocations into 1-d allocations, which would prevent the N-d
structure from being exposed to the codegen.
* Bugfix, flattening of Allocate/AllocateConst extents
Previously, these extents were ignored entirely by the flattening
pass. This went unnoticed so long as all allocations were 1-d, since
`StorageRewrite` erroneously flattened merged arrays into 1-d.
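To make the failure mode concrete, here is a hedged sketch (not code
from this commit; the scope, names, and extents are assumed) of a 2-d
physical allocation and the 1-d allocation that the old merging logic
would have produced for it, using the TVM C++ API:

// Illustrative sketch only: a 2-d physical allocation, and the 1-d
// allocation that the old merging logic in StorageRewrite produced.
#include <tvm/ir/type.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt.h>

using namespace tvm;
using namespace tvm::tir;

void Example() {
  // A buffer var in a 2-d memory scope (e.g. texture memory).
  Var buf("buf", PointerType(PrimType(DataType::Float(32)), "global.texture"));
  // Physical extents after StorageFlatten: two physical dimensions.
  Allocate alloc_2d(buf, DataType::Float(32), {16, 32}, const_true(), Evaluate(0));
  // Folding the extents into a single product, {16 * 32} = {512},
  // hides the 2-d layout from the codegen.
  Allocate alloc_1d(buf, DataType::Float(32), {512}, const_true(), Evaluate(0));
}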
---
src/tir/transforms/storage_flatten.cc | 97 ++++++++++++++++++++++++++++++++++-
src/tir/transforms/storage_rewrite.cc | 77 +++++++++++++++++++++------
2 files changed, 155 insertions(+), 19 deletions(-)
diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc
index 2bfc842..0923517 100644
--- a/src/tir/transforms/storage_flatten.cc
+++ b/src/tir/transforms/storage_flatten.cc
@@ -1405,12 +1405,25 @@ class StorageFlattener : public StmtExprMutator {
// rather than a buffer_var.
Stmt VisitStmt_(const AllocateNode* op) final {
buffer_var_defines_.insert(op->buffer_var.get());
- return StmtExprMutator::VisitStmt_(op);
+ auto stmt = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op));
+ return Allocate(stmt->buffer_var, stmt->dtype, FlattenExtents(stmt), stmt->condition,
+ stmt->body, stmt->annotations, stmt->span);
}
Stmt VisitStmt_(const AllocateConstNode* op) final {
buffer_var_defines_.insert(op->buffer_var.get());
- return StmtExprMutator::VisitStmt_(op);
+ auto stmt = Downcast<AllocateConst>(StmtExprMutator::VisitStmt_(op));
+ ObjectRef data_or_idx;
+ if (stmt->data) {
+ data_or_idx = stmt->data.value();
+ } else if (stmt->irmod_storage_idx) {
+ data_or_idx = stmt->irmod_storage_idx.value();
+ } else {
+ LOG(FATAL) << "Neither data array nor data index specified for allocation of const "
+ << op->buffer_var->name_hint;
+ }
+ return AllocateConst(stmt->buffer_var, stmt->dtype, FlattenExtents(stmt), data_or_idx,
+ stmt->body, stmt->span);
}
Stmt VisitStmt_(const LetStmtNode* op) final {
@@ -1598,6 +1611,82 @@ class StorageFlattener : public StmtExprMutator {
}
private:
+ // Helper function for visiting Allocate and AllocateConst. If, in
+ // the future, these are updated to hold a buffer (Buffer) object
+ // rather than a buffer_var (Var), this function can be replaced
+ // with a call to GetBufferEntry.
+ template <typename Node>
+ Array<PrimExpr> FlattenExtents(const Node& node) {
+ arith::Analyzer analyzer;
+
+ // Returns true if an allocation's extents match the buffer's shape
+ auto is_compatible_buffer = [&](const Buffer& buffer) {
+ if (buffer->shape.size() != node->extents.size()) {
+ return false;
+ }
+ for (size_t i = 0; i < buffer->shape.size(); i++) {
+ if (!analyzer.CanProveEqual(buffer->shape[i], node->extents[i])) {
+ return false;
+ }
+ }
+
+ return true;
+ };
+
+ auto int_array_equal = [](const Array<IntImm>& a, const Array<IntImm>& b) {
+ if (a.size() != b.size()) {
+ return false;
+ }
+
+ for (size_t i = 0; i < a.size(); i++) {
+ if (a[i]->value != b[i]->value) {
+ return false;
+ }
+ }
+
+ return true;
+ };
+
+ Array<IntImm> axis_separators;
+ auto it = buffer_var_map_.find(node->buffer_var.get());
+ if (it != buffer_var_map_.end()) {
+ const auto& buffers = it->second;
+ if (buffers.size() == 0) {
+ // No buffers use this allocation, treat as flat and optimize
+ // out later.
+ } else if (buffers.size() == 1) {
+ // Only one buffer uses this allocation, so use its axis
+ // separators.
+ axis_separators = buffers[0]->axis_separators;
+ } else {
+ // Try to find a buffer using this allocation with a matching
+ // shape.
+ Buffer compatible_buffer;
+ for (const auto& buffer : buffers) {
+ if (is_compatible_buffer(buffer)) {
+ ICHECK(!compatible_buffer.defined() ||
+ int_array_equal(compatible_buffer->axis_separators, buffer->axis_separators))
+ << "Cannot determine axis separators to use when flattening "
+ << node->buffer_var->name_hint
+ << ", multiple buffer objects found with conflicting axis
separators";
+ compatible_buffer = buffer;
+ }
+ }
+ ICHECK(compatible_buffer.defined())
+ << "Cannot determine axis separators to use when flattening "
+ << node->buffer_var->name_hint << ", no buffers found with matching shape";
+ axis_separators = compatible_buffer->axis_separators;
+ }
+ }
+
+ // Use GetFlattenedBuffer to determine the flattened shape of the
+ // output. We only need the shape and axis separators defined,
+ // everything else can be dummy values.
+ Buffer dummy_buffer =
decl_buffer(node->extents, DataType::Float(32), "buffer", "", axis_separators);
+ return dummy_buffer.GetFlattenedBuffer()->shape;
+ }
+
// The buffer entry in the flatten map
struct DimAlignInfo {
int align_factor{0};
@@ -1665,6 +1754,10 @@ class StorageFlattener : public StmtExprMutator {
// Set of vars that have occurred in an AllocateNode, but haven't
// yet occurred in a BufferLoad/BufferStore.
std::unordered_set<const VarNode*> buffer_var_defines_;
+ // Map from an allocation variable to the buffer(s) that it backs.
+ // Used to determine the axis_separators that should be used for
+ // flattening the extents of an AllocateNode.
+ std::unordered_map<const VarNode*, std::vector<Buffer>> buffer_var_map_;
// Buffer map
std::unordered_map<Buffer, BufferEntry, ObjectPtrHash, ObjectPtrEqual> buf_map_;
// The extern buffer map, updated to include flattened buffers.
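As a rough worked example of what the FlattenExtents helper above
computes (the shape and separator are assumed for illustration): with
extents {2, 3, 4} and a single axis separator after the first axis,
the flattened physical shape is {2, 12}; with no separators it would
be {24}. A hedged sketch of the same dummy-buffer trick:

// Hedged sketch of the dummy-buffer trick used by FlattenExtents;
// the shape and separator are assumed values, not from the commit.
#include <tvm/ir/expr.h>
#include <tvm/tir/buffer.h>

using namespace tvm;
using namespace tvm::tir;

Array<PrimExpr> FlattenedShapeExample() {
  // Only the shape and axis separators matter; the dtype and name
  // are placeholders, exactly as in FlattenExtents above.
  Array<IntImm> axis_separators = {IntImm(DataType::Int(32), 1)};
  Buffer dummy = decl_buffer({2, 3, 4}, DataType::Float(32), "buffer", "", axis_separators);
  // Axes on each side of the separator multiply together: {2, 3*4} = {2, 12}.
  return dummy.GetFlattenedBuffer()->shape;
}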
diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc
index 0534f31..d1a37e1 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -76,6 +76,8 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
};
// The scope of each allocation
struct AllocEntry {
+ // The physical dimensionality of the allocation.
+ size_t num_physical_dimensions{0};
// scope level
size_t level{0};
// allocation stmt
@@ -85,8 +87,16 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
void VisitStmt_(const AllocateNode* op) final {
size_t level = scope_.size();
const VarNode* buf = op->buffer_var.get();
- alloc_info_[buf].alloc = op;
- alloc_info_[buf].level = level;
+
+ AllocEntry entry;
+ entry.alloc = op;
+ entry.level = level;
+ // Since StorageRewrite occurs after StorageFlatten/FlattenBuffer,
+ // all allocations specify the extents of their physical dimensions,
+ // of which there is exactly one for flat memory spaces.
+ entry.num_physical_dimensions = op->extents.size();
+ alloc_info_[buf] = entry;
+
StmtExprVisitor::VisitStmt_(op);
}
@@ -104,6 +114,12 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
if (it != alloc_info_.end() && it->second.alloc) {
ICHECK_LT(it->second.level, scope_.size());
scope_[it->second.level].touched.push_back(buf);
+
+ ICHECK_EQ(op->buffer->axis_separators.size() + 1, it->second.num_physical_dimensions)
+ << "Buffer " << op->buffer->name << " is allocated with "
+ << it->second.num_physical_dimensions
+ << " physical dimensions, but is accessed as having "
+ << op->buffer->axis_separators.size() + 1 << " physical dimensions" << std::endl;
}
StmtEntry e = scope_.back();
scope_.pop_back();
@@ -125,6 +141,12 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
if (it != alloc_info_.end() && it->second.alloc) {
ICHECK_LT(it->second.level, scope_.size()) << "Load memory in places other than store.";
scope_[it->second.level].touched.push_back(buf);
+
+ ICHECK_EQ(op->buffer->axis_separators.size() + 1, it->second.num_physical_dimensions)
+ << "Buffer " << op->buffer->name << " is allocated with "
+ << it->second.num_physical_dimensions
+ << " physical dimensions, but is accessed as having "
+ << op->buffer->axis_separators.size() + 1 << " physical dimensions" << std::endl;
}
}
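The two checks added above enforce a single invariant: an
allocation's physical dimensionality (the number of extents after
flattening) must equal the axis_separators.size() + 1 physical
dimensions implied by every buffer that accesses it. A minimal
standalone sketch of that arithmetic, with assumed values:

// Hedged sketch of the invariant behind the ICHECK_EQ calls above.
#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  // An allocation flattened to extents {16, 32} has 2 physical dimensions.
  std::vector<int> extents = {16, 32};
  std::size_t num_physical_dimensions = extents.size();
  // A buffer with one axis separator also addresses 2 physical
  // dimensions: the axes before and after the separator.
  std::vector<std::size_t> axis_separators = {1};
  assert(axis_separators.size() + 1 == num_physical_dimensions);
  return 0;
}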
@@ -530,6 +552,10 @@ class StoragePlanRewriter : public StmtExprMutator {
uint64_t const_nbits{0};
// The storage scope.
StorageScope scope;
+ // The physical dimensionality of the allocations. Since
+ // StorageRewrite is applied after StorageFlatten/FlattenBuffer,
+ // this is the size of `AllocateNode::extents`.
+ size_t ndim;
// Allocs that shares this entry.
std::vector<const AllocateNode*> allocs;
// The children of this entry, not including itself.
@@ -629,8 +655,8 @@ class StoragePlanRewriter : public StmtExprMutator {
// simply use the original allocation.
PrimExpr sz = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
make_const(DataType::Int(32), 1), e->allocs[0]->extents);
- e->new_alloc =
- Allocate(e->alloc_var, alloc_type, {sz}, e->allocs[0]->condition, Evaluate(0));
+ e->new_alloc = Allocate(e->alloc_var, alloc_type, e->allocs[0]->extents,
+ e->allocs[0]->condition, Evaluate(0));
if (IsSpecialTaggedMemory(e->scope)) {
MemoryInfo info = GetMemoryInfo(e->scope.to_string());
uint64_t total_elem = e->const_nbits / e->elem_type.bits();
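For context on the hunk above: even when an entry is backed by a
single allocation, the old code re-created it with its extents folded
into one product ({sz}), while the new code passes the original
extents through unchanged. A hedged sketch of what the removed foldl
computed, with assumed extents:

// Hedged sketch (assumed extents): the fold that collapsed N-d
// extents into a single 1-d extent, e.g. {16, 32} -> {512}.
#include <cstdint>
#include <vector>

std::int64_t FoldedExtent(const std::vector<std::int64_t>& extents) {
  std::int64_t size = 1;  // identity element of the fold
  for (std::int64_t extent : extents) {
    size *= extent;  // the mul(a, b, span) step
  }
  return size;
}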
@@ -641,8 +667,13 @@ class StoragePlanRewriter : public StmtExprMutator {
// Build a merged allocation
PrimExpr combo_size;
for (const AllocateNode* op : e->allocs) {
- PrimExpr sz = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
- make_const(DataType::Int(32), 1), op->extents);
+ ICHECK_EQ(op->extents.size(), 1)
+ << "Buffer var " << op->buffer_var->name_hint
+ << " was identified as a re-usable allocation, but has " <<
op->extents.size()
+ << " physical dimensions. "
+ << "Currently, only flat 1-d memory spaces should be
identified as re-usable "
+ "allocations.";
+ PrimExpr sz = op->extents[0];
auto nbits = op->dtype.bits() * op->dtype.lanes();
if (const auto* imm = sz.as<IntImmNode>()) {
if (imm->value > std::numeric_limits<int>::max() / nbits) {
@@ -790,7 +821,8 @@ class StoragePlanRewriter : public StmtExprMutator {
for (const VarNode* var : it->second.gen) {
ICHECK(alloc_info.count(var));
- const AllocateNode* alloc = alloc_info.at(var).alloc;
+ const AllocEntry& entry = alloc_info.at(var);
+ const AllocateNode* alloc = entry.alloc;
auto storage_scope = StorageScope::Create(GetPtrStorageScope(GetRef<Var>(var)));
StorageEntry* dst_entry = nullptr;
// inplace detection
@@ -818,7 +850,8 @@ class StoragePlanRewriter : public StmtExprMutator {
}
}
if (dst_entry == nullptr) {
- dst_entry = FindAlloc(alloc, thread_scope_, storage_scope);
+ dst_entry =
+ FindAlloc(alloc, thread_scope_, storage_scope, entry.num_physical_dimensions);
}
dst_entry->allocs.emplace_back(alloc);
alloc_map_[var] = dst_entry;
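The num_physical_dimensions value threaded into FindAlloc here feeds
the re-use decision shown in the final hunk below. A simplified,
hedged sketch of that decision (it deliberately omits the small-array
and special-memory-tag rules visible below):

// Hedged sketch: only flat 1-d allocations of known constant size are
// candidates for re-use; everything else gets its own allocation.
#include <cstddef>
#include <cstdint>

bool IsReuseCandidate(std::size_t num_physical_dimensions, std::uint64_t const_nbits) {
  bool is_known_size = (const_nbits != 0);
  bool is_flat_memory_space = (num_physical_dimensions == 1);
  return is_known_size && is_flat_memory_space;
}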
@@ -871,24 +904,34 @@ class StoragePlanRewriter : public StmtExprMutator {
}
StorageEntry* FindAlloc(const AllocateNode* op, const Object* attach_scope,
- const StorageScope& scope) {
+ const StorageScope& scope, size_t num_physical_dimensions) {
ICHECK(op != nullptr);
// skip plan for local variable,
// compiler can do a better job with register allocation.
const uint64_t match_range = 16;
uint64_t op_elem_bits = op->dtype.bits() * op->dtype.lanes();
uint64_t const_nbits = static_cast<uint64_t>(op->ConstantAllocationSize() * op_elem_bits);
+
+ // If the size of the array isn't known at compile-time, it must
+ // have its own allocation with size determined at runtime.
+ bool is_known_size = (const_nbits != 0);
+
+ // Currently, only flat memory spaces can be re-used. Packing
+ // into N-d space (e.g. 2-d texture memory on GPUs) will require
+ // more in-depth algorithms.
+ bool is_flat_memory_space = (num_physical_dimensions == 1);
+
// disable reuse of small arrays; they will be lowered to registers in LLVM
// These rules only apply if we are using non-special memory
- if (scope.tag.length() == 0) {
- if (scope.rank >= StorageRank::kWarp || op->dtype.is_handle()) {
- return NewAlloc(op, attach_scope, scope, const_nbits);
- }
- if (const_nbits > 0 && const_nbits <= 32) {
- return NewAlloc(op, attach_scope, scope, const_nbits);
- }
+ bool is_small_array =
+ (scope.tag.length() == 0) && (scope.rank >= StorageRank::kWarp || op->dtype.is_handle() ||
+ (is_known_size && const_nbits <= 32));
+
+ if (is_small_array || !is_flat_memory_space) {
+ return NewAlloc(op, attach_scope, scope, const_nbits);
}
- if (const_nbits != 0) {
+
+ if (is_known_size) {
// constant allocation.
auto begin = const_free_map_.lower_bound(const_nbits / match_range);
auto mid = const_free_map_.lower_bound(const_nbits);