This is an automated email from the ASF dual-hosted git repository.
moreau pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 659f3b7 [TIR] tir.transform.StorageFlatten refactor (#9091)
659f3b7 is described below
commit 659f3b7fe10e1ca5919c554bc6eec1ba743176e2
Author: Lunderberg <[email protected]>
AuthorDate: Thu Sep 30 19:12:31 2021 -0500
[TIR] tir.transform.StorageFlatten refactor (#9091)
* [TE] Improved flexibility of ArgBinder::BindDLTensor
Allowed a compact DLTensor to bind to a Buffer object that defines
strides, if the strides defined correspond to a compact layout.
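As an illustrative sketch (plain Python, not a TVM API; is_compact is
a hypothetical helper), a strided layout is compact iff each stride
equals the product of all faster-varying extents:

    # Condition under which a compact DLTensor (null strides) may now
    # bind to a Buffer that defines strides.
    def is_compact(shape, strides):
        expected = 1
        for dim, stride in zip(reversed(shape), reversed(strides)):
            if stride != expected:
                return False
            expected *= dim
        return True

    assert is_compact([2, 3, 4], [12, 4, 1])      # row-major compact
    assert not is_compact([2, 3, 4], [16, 4, 1])  # padded outer stride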
* [TIR] Exposed ElemOffset as a member function of BufferNode.
* [TE] Pulled shape determination out of StorageFlattener
Previously, StorageFlattener would determine the shape of a physical
buffer based on the extents of the BufferRealizeNode. Pulled these
out into a separate BufferShapeLegalize pass. After this pass, all
buffers have a shape that matches the buffer realization extents.
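A rough sketch of the rewrite (plain Python, illustrative only): the
realize extents become the buffer shape, and accesses are offset so
that indices start at zero.

    # Hypothetical helper mirroring the BufferShapeLegalize rewrite.
    def legalize(bounds, index):
        # bounds: list of (min, extent) from BufferRealize;
        # index: an access into the original buffer.
        new_shape = [extent for _, extent in bounds]
        new_index = [i - lo for i, (lo, _) in zip(index, bounds)]
        return new_shape, new_index

    # A buffer realized on [10, 50) with access A[25]:
    assert legalize([(10, 40)], [25]) == ([40], [15])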
* [TE] Refactor stride calculation out of StorageFlattener
Previously, StorageFlattener would handle any attr::dim_align
annotations. Now, this is pulled out into a separate
BufferStrideLegalize pass.
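The aligned-stride computation itself is unchanged; a Python rendering
of the loop that now lives in BufferStrideLegalize::WithStrides
(illustrative, mirrors the C++ in the diff below):

    # align: dict of dim -> (align_factor, align_offset).
    def aligned_strides(shape, align):
        strides, stride = [], 1
        for dim in reversed(range(len(shape))):
            factor, offset = align.get(dim, (0, 0))
            if factor != 0:
                # Round the stride up to satisfy the alignment request.
                stride += (factor + offset - stride % factor) % factor
            strides.append(stride)
            stride *= shape[dim]
        return list(reversed(strides))

    # Pad the row stride of a (4, 6) buffer to a multiple of 8:
    assert aligned_strides([4, 6], {0: (8, 0)}) == [8, 1]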
* [TE] Refactor thread scope propagation out of StorageFlattener.
Previously, StorageFlattener would use the scope in IterVar to assign
a scope to allocated buffers, where not otherwise defined. This has
been pulled out into a separate ThreadScopePropagate pass.
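For example (a hedged TE-level illustration; the exact lowered IR may
differ), a buffer realized inside a thread extent with no explicit
scope now receives its default scope from this standalone pass rather
than from StorageFlattener:

    import tvm
    from tvm import te

    A = te.placeholder((64,), name="A")
    B = te.compute((64,), lambda i: A[i] + 1.0, name="B")
    C = te.compute((64,), lambda i: B[i] * 2.0, name="C")

    s = te.create_schedule(C.op)
    bx, tx = s[C].split(C.op.axis[0], factor=32)
    s[C].bind(bx, te.thread_axis("blockIdx.x"))
    s[C].bind(tx, te.thread_axis("threadIdx.x"))
    # B is realized inside the thread extent; lacking an explicit
    # scope, it is assigned the default scope for that thread rank.
    s[B].compute_at(s[C], tx)

    print(tvm.lower(s, [A, C], simple_mode=True))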
* [TE] Refactor buffer bind mapping out of StorageFlattener.
Previously, StorageFlattener would look for `attr::buffer_bind_scope`
to determine if a Buffer object is a view into another buffer, and
would apply that mapping while making the Allocate/Store/Load nodes.
Now, the mapping of buffer binds is pulled out into a separate
BufferBindUnwrapper pass.
This also resolves an issue in which BufferLoad/BufferStore nodes that
refer to a Buffer defined through `attr::buffer_bind_scope` would
generate Load/Store nodes that point to the linked buffer, rather than
the actual buffer.
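A sketch of the unwrapping (plain Python, hypothetical helper): an
access into a view becomes an access into the underlying buffer,
offset by the view's begins.

    # V is bound to B[2:6, 4:8]; V[1, 3] must read B[3, 7].
    def unwrap_index(begins, indices):
        return [b + i for b, i in zip(begins, indices)]

    assert unwrap_index([2, 4], [1, 3]) == [3, 7]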
* [TIR] Removed checks on buffer->shape.size()
Even after BufferShapeLegalize, rank-zero tensors may have an empty
shape.
* [TIR] Relaxed check on a bufferview's striding.
The original refactoring required that a bufferview have no explicit
striding, instead taking the striding from the buffer that it is
viewing. Modified to allow a bufferview to specify striding, so long
as it is consistent with the viewed buffer's striding. This reproduces
the behavior of StorageFlatten before the refactoring.
* [TIR] Fixed StorageFlatten test for shape_legalize.
AttrStmtNodes that contain rewritten Buffers need to be rewritten as
well.
* [TIR] Assigned storage scope
The earlier stage of the refactor left a buffer's storage scope
undefined if its scope was not determined by the IterVar of a loop
containing its allocation. Now, these are explicitly set to
StorageScope::kGlobal, to match the previous behavior of
StorageFlatten.
* Updated ICHECK_EQ to CHECK_EQ for a test that depends on user-provided
data.
* Added comments in storage_flatten.cc, indicating why buffer_bind_scope
needs special handling.
* Updated comment with a few examples of where compact buffers are
assumed to have no strides defined.
* Updated following @csullivan's comments.
* Added fuzzy mapping to the BufferShapeLegalize.
Maintains the earlier behavior of StorageFlatten, which allows buffer
views to be mapped to higher-dimensional buffers, so long as the view
extent is 1 in each extra dimension.
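A sketch of the fuzzy rule (plain Python, hypothetical helper):
indices of a lower-rank view are padded with leading zeros to match
the rank of the target.

    def fuzzy_pad(indices, target_rank):
        assert target_rank >= len(indices)
        return [0] * (target_rank - len(indices)) + list(indices)

    # An [n, m] view bound to a [1, 1, n, m] region:
    assert fuzzy_pad([4, 5], 4) == [0, 0, 4, 5]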
* Updated BufferShapeLegalize; asserts need to be inside the
buffer_bind_scope.
* Pulled all shape-dependent behavior into BufferShapeLegalize.
Previously, BufferBindUnwrapper passed fuzzy_match=true to
ArgBinder::BindBuffer, which could change the number of dimensions.
Now, all buffer dimensions should be updated prior to
BufferBindUnwrapper, and it is an error to have mismatched dimensions
in BufferBindUnwrapper.
* Added another pass to remove verifiable assert statements.
ArgBinder::BindBuffer inserts these assert statements if they are not
verifiable at the time of substitution. Previously, with one giant
substitution, the assertions were verifiable at that time. After the
refactor, with substitutions done in multiple stages for
shape/stride/buffer_bind_scope, we need to clean up any assertions
that are verifiable after all substitutions have occurred.
* Minor cleanup
- Removed StorageFlattener::BufferEntry::RelIndex, behavior already
handled by BufferShapeLegalize.
- Improved comments and error messages.
- Extracted duplicate behavior in BufferLoad/BufferStore handling in
BufferShapeLegalize.
* Updated to handle BufferRealizeNode with no defined bounds.
* Updated to be less aggressive when checking AssertStmt
A true Assert statement can be removed, but a false Assert statement
requires CFA to give as a compile-time error. Since we only need the
removal of true assert statements, skipping the CFA this time.
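The removal criterion can be sketched with the arithmetic analyzer
directly (the pass itself rewrites AssertStmt nodes; this is an
illustration, not the pass):

    import tvm
    from tvm import tir

    ana = tvm.arith.Analyzer()
    cond = tir.const(4, "int32") * 2 == tir.const(8, "int32")
    # A condition that simplifies to the constant True lets the
    # enclosing AssertStmt be replaced by its body; anything else,
    # including a provably false condition, is left unmodified.
    print(ana.simplify(cond))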
---
include/tvm/tir/buffer.h | 8 +
src/tir/ir/buffer.cc | 24 +-
src/tir/transforms/arg_binder.cc | 25 +-
src/tir/transforms/storage_flatten.cc | 1279 ++++++++++++++++++++++++++++-----
4 files changed, 1136 insertions(+), 200 deletions(-)
diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h
index 28d202c..f04209d 100644
--- a/include/tvm/tir/buffer.h
+++ b/include/tvm/tir/buffer.h
@@ -121,6 +121,14 @@ class BufferNode : public Object {
return shape.size() != 0 ? shape[0].dtype() : DataType::Int(32);
}
+ /*! \brief Determine the offset in the buffer of the given index.
+ *
+ * Returns the buffer offset, in number of elements of type dtype,
+ * without adjusting for number of lanes. (e.g. The number of
+ * float16x4 elements in a buffer of type float16x4.)
+ */
+ PrimExpr ElemOffset(Array<PrimExpr> index) const;
+
static constexpr const char* _type_key = "tir.Buffer";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc
index 8253fa3..24aacc3 100644
--- a/src/tir/ir/buffer.cc
+++ b/src/tir/ir/buffer.cc
@@ -246,41 +246,41 @@ inline PrimExpr MergeMulMod(arith::Analyzer* analyzer, const PrimExpr& base) {
// The buffer offset in convention of number of elements of
// original data ignoring number of lanes.
// We also perform optimization to simplify the indexing expression.
-inline PrimExpr ElemOffset(const BufferNode* n, Array<PrimExpr> index) {
- PrimExpr base = n->elem_offset;
+PrimExpr BufferNode::ElemOffset(Array<PrimExpr> index) const {
+ PrimExpr base = this->elem_offset;
arith::Analyzer ana;
- if (n->strides.size() == 0) {
+ if (this->strides.size() == 0) {
// Scalar case
- if (n->shape.size() == 0 && index.size() == 1) {
+ if (this->shape.size() == 0 && index.size() == 1) {
auto is_int = index[0].as<IntImmNode>();
ICHECK(is_int && is_int->value == 0);
base = base + index[0];
} else {
- ICHECK_EQ(n->shape.size(), index.size());
+ ICHECK_EQ(this->shape.size(), index.size());
if (index.size() > 0) {
PrimExpr offset = index[0];
for (size_t i = 1; i < index.size(); ++i) {
- offset = MergeMulMod(&ana, offset * n->shape[i] + index[i]);
+ offset = MergeMulMod(&ana, offset * this->shape[i] + index[i]);
}
base = base + offset;
}
}
} else {
- ICHECK_EQ(n->strides.size(), index.size());
+ ICHECK_EQ(this->strides.size(), index.size());
if (is_zero(base)) {
- base = MergeMulMod(&ana, index[0] * n->strides[0]);
+ base = MergeMulMod(&ana, index[0] * this->strides[0]);
} else {
- base = MergeMulMod(&ana, base + index[0] * n->strides[0]);
+ base = MergeMulMod(&ana, base + index[0] * this->strides[0]);
}
for (size_t i = 1; i < index.size(); ++i) {
- base = MergeMulMod(&ana, base + index[i] * n->strides[i]);
+ base = MergeMulMod(&ana, base + index[i] * this->strides[i]);
}
}
return base;
}
inline PrimExpr BufferOffset(const BufferNode* n, Array<PrimExpr> index, DataType dtype) {
- PrimExpr offset = ElemOffset(n, index);
+ PrimExpr offset = n->ElemOffset(index);
if (n->dtype.lanes() != 1) {
offset = offset * make_const(offset.dtype(), dtype.lanes());
}
@@ -353,7 +353,7 @@ Buffer Buffer::MakeSlice(Array<PrimExpr> begins, Array<PrimExpr> extents) const
ICHECK(n != nullptr);
arith::Analyzer ana;
begins = SimplifyArray(&ana, begins);
- PrimExpr elem_offset = ana.Simplify(ElemOffset(n, begins));
+ PrimExpr elem_offset = ana.Simplify(n->ElemOffset(begins));
Array<PrimExpr> strides = n->strides;
if (strides.size() == 0) {
bool can_relax = true;
diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc
index 293c990..d3ab32c 100644
--- a/src/tir/transforms/arg_binder.cc
+++ b/src/tir/transforms/arg_binder.cc
@@ -204,7 +204,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
def_handle_dtype_.Set(v_strides, tir::TypeAnnotation(tvm_shape_type));
init_nest_.emplace_back(
LetStmt(v_strides, TVMArrayGet(DataType::Handle(), handle, builtin::kArrStrides), nop));
- PrimExpr is_null = Call(DataType::Bool(1), builtin::isnullptr(), {v_strides});
+ PrimExpr v_strides_is_null = Call(DataType::Bool(1), builtin::isnullptr(), {v_strides});
if (buffer->strides.size() == 0) {
// Assert the buffer is compact
DataType stype = buffer->DefaultIndexType();
@@ -226,7 +226,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
foldl([](PrimExpr a, PrimExpr b, Span span) { return logical_and(a, b, span); },
const_true(1), conds),
stride_msg, Evaluate(0));
- check = IfThenElse(Not(is_null), check, Stmt());
+ check = IfThenElse(Not(v_strides_is_null), check, Stmt());
asserts_.emplace_back(SeqStmt({check, Evaluate(0)}));
}
} else if (buffer->buffer_type == kAutoBroadcast) {
@@ -239,24 +239,29 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
PrimExpr value =
cast(buffer->shape[k].dtype(),
Load(tvm_shape_type, v_strides, IntImm(DataType::Int(32), k), const_true(1)));
- value = tvm::if_then_else(is_null, stride, value);
+ value = tvm::if_then_else(v_strides_is_null, stride, value);
value = tvm::if_then_else(buffer->shape[k] == 1, 0, value);
Bind_(buffer->strides[k], value, field_name.str(), true);
stride = analyzer_.Simplify(stride * buffer->shape[k]);
}
} else {
- std::ostringstream stride_null_err_msg;
- stride_null_err_msg << arg_name << ".strides: expected non-null strides.";
- asserts_.emplace_back(
- AssertStmt(Not(is_null), tvm::tir::StringImm(stride_null_err_msg.str()), nop));
+ PrimExpr stride_from_shape = 1;
- for (size_t k = 0; k < buffer->strides.size(); ++k) {
+ for (int k = buffer->strides.size() - 1; k >= 0; k--) {
std::ostringstream field_name;
field_name << v_strides->name_hint << '[' << k << ']';
+
+ PrimExpr explicit_stride =
+ cast(buffer->shape[k].dtype(),
+ Load(tvm_shape_type, v_strides, IntImm(DataType::Int(32), k), const_true(1)));
+
Bind_(buffer->strides[k],
- cast(buffer->shape[k].dtype(),
- Load(tvm_shape_type, v_strides, IntImm(DataType::Int(32), k), const_true(1))),
+ tvm::if_then_else(v_strides_is_null, stride_from_shape, explicit_stride),
field_name.str(), true);
+
+ stride_from_shape *=
+ cast(buffer->shape[k].dtype(),
+ Load(tvm_shape_type, v_shape, IntImm(DataType::Int(32), k), const_true(1)));
}
}
// Byte_offset field.
diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc
index 2c32cc7..6a3ce59 100644
--- a/src/tir/transforms/storage_flatten.cc
+++ b/src/tir/transforms/storage_flatten.cc
@@ -50,6 +50,984 @@ using runtime::StorageRank;
using runtime::StorageScope;
using runtime::ThreadScope;
+/* Make buffer realize extents and buffer shapes consistent
+ *
+ * For external buffers, verify that the extents of BufferRealize
+ * nodes match the shape of the external buffer. For internal
+ * buffers, rewrite the shape of the Buffer objects to match the
+ * extent of the BufferRealize, and rewrite indices of
+ * BufferLoad/BufferStore nodes to match.
+ */
+class BufferShapeLegalize : public StmtExprMutator {
+ public:
+ explicit BufferShapeLegalize(const Map<Var, Buffer>& extern_buffer_map,
+ IRVisitorWithAnalyzer* bound_analyzer)
+ : bound_analyzer_(bound_analyzer) {
+ for (auto kv : extern_buffer_map) {
+ Buffer buf = kv.second;
+ extern_buffers_.insert(buf);
+
+ BufferEntry remap;
+ remap.remap_to = buf;
+ remap.index_offsets = Array<PrimExpr>(buf->shape.size(), 0);
+ remap.in_scope = true;
+ buf_map_[buf] = remap;
+ }
+ }
+
+ PrimExpr VisitExpr_(const VarNode* op) final {
+ auto it = var_remap_.find(op);
+ if (it != var_remap_.end()) {
+ return it->second;
+ } else {
+ return GetRef<PrimExpr>(op);
+ }
+ }
+
+ Stmt VisitStmt_(const BufferRealizeNode* op) final {
+ // BufferRealizeNode for an external buffer serves as an
+ // annotation of the external buffers, and should not be changed.
+ // Instead, verify that the bounds match the external
+ // buffer.
+ if (extern_buffers_.count(op->buffer)) {
+ CHECK_EQ(op->buffer->shape.size(), op->bounds.size())
+ << "External buffer realize has mismatched dimension";
+ Stmt stmt = StmtExprMutator::VisitStmt_(op);
+ op = stmt.as<BufferRealizeNode>();
+ ICHECK(op);
+
+ for (size_t i = 0; i < op->bounds.size(); i++) {
+ PrimExpr eq = bound_analyzer_->Simplify(op->buffer->shape[i] == op->bounds[i]->extent);
+ std::ostringstream ss;
+ ss << "Dim " << i << " of external buffer " << op->buffer->name << "
has shape "
+ << op->buffer->shape[i] << ", but is only realized for extent " <<
op->bounds[i]->extent;
+ if (auto eq_int = eq.as<IntImmNode>()) {
+ ICHECK(eq_int->value) << ss.str();
+ } else {
+ stmt = AssertStmt(eq, tvm::tir::StringImm(ss.str()), stmt);
+ }
+ }
+ return stmt;
+ }
+
+ // Compute the new buffer shape, new realization bounds, and the
+ // offsets to be applied to buffer access.
+ Array<PrimExpr> realized_shape;
+ Array<PrimExpr> index_offsets;
+ Array<Range> new_bounds;
+ for (size_t i = 0; i < op->bounds.size(); i++) {
+ const Range& bound = op->bounds[i];
+ realized_shape.push_back(bound->extent);
+ index_offsets.push_back(bound->min);
+ new_bounds.push_back({0, bound->extent});
+ }
+
+ if (op->buffer->shape.size()) {
+ ICHECK_EQ(op->buffer->shape.size(), realized_shape.size())
+ << "Inconsistency between dimension of buffer " << op->buffer
+ << " and dimension of its realized bounds.";
+ }
+
+ Buffer key = op->buffer;
+
+ Buffer buf = op->buffer;
+ auto write_ptr = buf.CopyOnWrite();
+ write_ptr->shape = realized_shape;
+
+ {
+ BufferEntry remap;
+ remap.remap_to = buf;
+ remap.index_offsets = index_offsets;
+ remap.in_scope = true;
+ buf_map_[key] = remap;
+ }
+
+ Stmt stmt = BufferRealize(buf, new_bounds, op->condition, this->VisitStmt(op->body), op->span);
+
+ buf_map_.at(key).in_scope = false;
+
+ return stmt;
+ }
+
+ Stmt VisitStmt_(const BufferStoreNode* op) final {
+ Stmt stmt = StmtExprMutator::VisitStmt_(op);
+ op = stmt.as<BufferStoreNode>();
+ ICHECK(op);
+
+ auto it = buf_map_.find(op->buffer);
+ if (it != buf_map_.end()) {
+ const BufferEntry& entry = it->second;
+ ICHECK(entry.in_scope) << "Cannot store to an out-of-scope buffer";
+
+ BufferStore updated = GetRef<BufferStore>(op);
+ auto write_ptr = updated.CopyOnWrite();
+ write_ptr->indices = update_indices(op->indices, entry.index_offsets);
+ write_ptr->buffer = entry.remap_to;
+ stmt = updated;
+ }
+
+ return stmt;
+ }
+
+ PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+ PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+ op = expr.as<BufferLoadNode>();
+ ICHECK(op);
+
+ auto it = buf_map_.find(op->buffer);
+ if (it != buf_map_.end()) {
+ const BufferEntry& entry = it->second;
+ ICHECK(entry.in_scope) << "Cannot read from an out-of-scope buffer";
+
+ BufferLoad updated = GetRef<BufferLoad>(op);
+ auto write_ptr = updated.CopyOnWrite();
+ write_ptr->indices = update_indices(op->indices, entry.index_offsets);
+ write_ptr->buffer = entry.remap_to;
+ expr = updated;
+ }
+
+ return expr;
+ }
+
+ Stmt VisitStmt_(const AttrStmtNode* op) final {
+ if (op->node->IsInstance<tir::BufferNode>()) {
+ // Visit body before checking buf_map_, because we
+ // don't know if the BufferNode needs to be changed until we
+ // look in the body for a BufferRealizeNode with different
+ // extents.
+ Stmt body = this->VisitStmt(op->body);
+
+ Buffer buffer = Downcast<tir::Buffer>(op->node);
+ auto it = buf_map_.find(buffer);
+ if (it != buf_map_.end()) {
+ buffer = it->second.remap_to;
+ return AttrStmt(it->second.remap_to, op->attr_key, op->value, body);
+ }
+ return AttrStmt(buffer, op->attr_key, op->value, body);
+
+ } else if (op->attr_key == attr::buffer_bind_scope) {
+ return HandleBufferBindScope(op);
+ }
+
+ return StmtExprMutator::VisitStmt_(op);
+ }
+
+ private:
+ // Any buffers that give views into a resized buffer should be
+ // updated, both to refer to the resized buffer and to have the view
+ // window updated. For example, suppose B1 is a 1-D buffer of size
+ // 100 which is only realized on the range (10,50), and buffer V1 is
+ // a view into B1[25:35]. When B1 is replaced with B2, a buffer of
+ // size 40 realized on the range (0,40), V1 must be replaced to be a
+ // view into B2[15:25].
+ Stmt HandleBufferBindScope(const AttrStmtNode* op) {
+ Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
+ ICHECK_EQ(arr.size(), 2U);
+ Buffer buffer = Downcast<Buffer>(arr[0]);
+ ICHECK(buffer.defined());
+ Buffer target = Downcast<Buffer>(arr[1]);
+ ICHECK(target.defined());
+
+ auto it = buf_map_.find(target);
+ ICHECK(it != buf_map_.end()) << "attr::buffer_bind_scope target " << target << " not in scope.";
+ const BufferEntry& target_remap = it->second;
+
+ ICHECK(target_remap.in_scope) << "Cannot bind " << buffer->name
+ << " to the out-of-scope buffer " <<
target_remap.remap_to->name;
+
+ Call tuple = Downcast<Call>(op->value);
+ ICHECK(tuple.defined() && tuple->op.same_as(builtin::tvm_tuple()));
+
+ Array<PrimExpr> new_tuple_args;
+ Array<PrimExpr> realized_begins;
+ Array<PrimExpr> view_shape;
+ ICHECK_EQ(tuple->args.size(), target_remap.index_offsets.size() * 2)
+ << "attr::buffer_bind_scope to define " << buffer << " as a view into
" << target
+ << " does match dimensionality of " << target;
+ for (size_t i = 0; i < target_remap.index_offsets.size(); i++) {
+ PrimExpr parent_begin = tuple->args[2 * i];
+ PrimExpr view_extent = tuple->args[2 * i + 1];
+ // Offset the begin of the buffer view by the offset of the target buffer.
+ new_tuple_args.push_back(parent_begin - target_remap.index_offsets[i]);
+ // Keep the extent of the buffer view the same.
+ new_tuple_args.push_back(view_extent);
+ // Use the extent of the buffer view to define the buffer view's shape.
+ view_shape.push_back(view_extent);
+ // Within the buffer view, indices start at 0.
+ realized_begins.push_back(0);
+ }
+
+ // If a view is binding to a buffer of a higher dimensionality,
+ // then the leading dimensions should be padded out with shape of
+ // 1.
+ ICHECK_GE(view_shape.size(), buffer->shape.size())
+ << "Cannot bind " << buffer << " to a shape of lower dimension.";
+ if (view_shape.size() > buffer->shape.size()) {
+ size_t diff = view_shape.size() - buffer->shape.size();
+ Array<PrimExpr> padded_shape;
+ for (size_t i = 0; i < diff; i++) {
+ padded_shape.push_back(1);
+ }
+ for (auto dim : buffer->shape) {
+ padded_shape.push_back(dim);
+ }
+ view_shape = std::move(padded_shape);
+ }
+
+ // If a buffer has strides defined, and is being remapped into a
+ // shape with additional dimensions, then define dummy values for
+ // the strides.
+ Array<PrimExpr> realized_strides = buffer->strides;
+ if ((realized_strides.size() > 0) && (realized_strides.size() != view_shape.size())) {
+ ICHECK_GE(view_shape.size(), realized_strides.size())
+ << "Cannot bind the strides of " << buffer << " to a shape of lower
dimension";
+ size_t diff = view_shape.size() - buffer->strides.size();
+
+ Array<PrimExpr> updated_strides;
+ for (size_t i = 0; i < diff; i++) {
+ updated_strides.push_back(Var("stride", buffer->shape[0].dtype()));
+ }
+ for (auto stride : buffer->strides) {
+ updated_strides.push_back(stride);
+ }
+ realized_strides = updated_strides;
+ }
+
+ Buffer key = buffer;
+
+ auto write_ptr = buffer.CopyOnWrite();
+ write_ptr->shape = view_shape;
+ write_ptr->strides = realized_strides;
+
+ {
+ BufferEntry remap;
+ remap.index_offsets = realized_begins;
+ remap.remap_to = buffer;
+ remap.in_scope = true;
+ buf_map_[key] = remap;
+ }
+
+ // Define remappings of any Variables referencing Buffer internals
+ // (e.g. Store/Load nodes). Passing fuzzy_match=true allows the
+ // remapped buffer to have a different number of dimensions.
+ ArgBinder binder(&var_remap_);
+ binder.BindBuffer(key, buffer, key->name, true);
+
+ Stmt body = this->VisitStmt(op->body);
+ body = MergeNest(binder.asserts(), body);
+ body = MergeNest(binder.init_nest(), body);
+
+ Stmt stmt = AttrStmt(Array<ObjectRef>{buffer, target_remap.remap_to}, op->attr_key,
+ Call(tuple->dtype, tuple->op, new_tuple_args, tuple->span), body);
+
+ for (const Var& v : binder.defs()) {
+ var_remap_.erase(v.get());
+ }
+
+ buf_map_.at(key).in_scope = false;
+ return stmt;
+ }
+
+ Array<PrimExpr> update_indices(const Array<PrimExpr>& indices, const Array<PrimExpr>& offsets) {
+ // offsets come from BufferRealizeNode::bounds, which is allowed
+ // to be empty to indicate realization of the full shape of the
+ // buffer. In that case, the indices do not need to be modified,
+ // but may need to be extended with leading zeroes.
+ if (offsets.size() == 0) {
+ return indices;
+ }
+
+ ICHECK_GE(offsets.size(), indices.size())
+ << "Cannot bind buffer to a shape of lower dimension.";
+
+ Array<PrimExpr> new_indices;
+
+ // Pad leading indices with zero, matching the "fuzzy_match"
+ // behavior from ArgBinder::BindBuffer.
+ size_t diff = offsets.size() - indices.size();
+ for (size_t i = 0; i < diff; i++) {
+ new_indices.push_back(0);
+ }
+
+ // Offset indices used to access buffers of a reduced size.
+ for (size_t i = 0; i < indices.size(); i++) {
+ PrimExpr offset = offsets[i + diff];
+ new_indices.push_back(indices[i] - offset);
+ }
+
+ return new_indices;
+ }
+
+ std::unordered_map<const VarNode*, PrimExpr> var_remap_;
+
+ std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> extern_buffers_;
+
+ struct BufferEntry {
+ Buffer remap_to;
+ Array<PrimExpr> index_offsets;
+ bool in_scope;
+ };
+
+ std::unordered_map<Buffer, BufferEntry, ObjectPtrHash, ObjectPtrEqual> buf_map_;
+
+ IRVisitorWithAnalyzer* bound_analyzer_;
+};
+
+/* Apply dimension alignment restrictions
+ *
+ * Buffers annotated with attr::buffer_dim_align may need to have
+ * strides defined such that they are no longer in a compact shape.
+ * After this pass, buffers have stride definitions to include these
+ * alignment restrictions, and attr::buffer_dim_align annotations have
+ * been removed.
+ */
+class BufferStrideLegalize : public StmtExprMutator {
+ public:
+ explicit BufferStrideLegalize(const Map<Var, Buffer>& extern_buffer_map,
+ IRVisitorWithAnalyzer* bound_analyzer)
+ : bound_analyzer_(bound_analyzer) {
+ for (auto kv : extern_buffer_map) {
+ Buffer buf = kv.second;
+ Buffer with_strides = WithStrides(buf);
+ {
+ BufferEntry entry;
+ entry.remap_to = with_strides;
+ entry.in_scope = true;
+ entry.is_external = true;
+ buf_map_[buf] = entry;
+ }
+ updated_extern_buffer_map_.Set(kv.first, with_strides);
+ }
+ }
+
+ Map<Var, Buffer> UpdatedExternBufferMap() const { return updated_extern_buffer_map_; }
+
+ Buffer WithStrides(Buffer buf) {
+ auto it = buf_map_.find(buf);
+ if (it != buf_map_.end()) {
+ const BufferEntry& entry = it->second;
+ ICHECK(entry.in_scope) << "Cannot annotate an out-of-scope buffer";
+ return entry.remap_to;
+ }
+
+ if (buf->strides.size()) {
+ ICHECK_EQ(buf->strides.size(), buf->shape.size())
+ << "Buffer " << buf << " has inconsistent strides/shape.";
+ return buf;
+ }
+
+ // Keeping this to have matched behavior to previous version.
+ // There are many parts of the codebase that assume that a strided
+ // array cannot be compact. For example, ArgBinder::BindBuffer
+ // and tir.Specialize.
+ if (dim_align_.count(buf) == 0) {
+ return buf;
+ }
+
+ // Can't define the strides for a buffer without a known shape.
+ Array<PrimExpr> shape = buf->shape;
+ if (shape.size() == 0) {
+ return buf;
+ }
+
+ std::vector<PrimExpr> rstrides;
+ const std::vector<DimAlignInfo>& avec = dim_align_[buf];
+ int first_dim = 0;
+ PrimExpr stride = make_const(shape[first_dim].dtype(), 1);
+ for (size_t i = shape.size(); i != 0; --i) {
+ size_t dim = i - 1;
+ if (dim < avec.size() && avec[dim].align_factor != 0) {
+ PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor);
+ PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset);
+ stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor);
+ stride = bound_analyzer_->Simplify(stride);
+ }
+ rstrides.push_back(stride);
+ stride = stride * shape[dim];
+ }
+
+ auto ptr = buf.CopyOnWrite();
+ ptr->strides = Array<PrimExpr>(rstrides.rbegin(), rstrides.rend());
+
+ return buf;
+ }
+
+ Stmt VisitStmt_(const AttrStmtNode* op) final {
+ if (op->attr_key == attr::buffer_dim_align) {
+ auto buffer = Downcast<tir::Buffer>(op->node);
+ const CallNode* tuple = op->value.as<CallNode>();
+ ICHECK(tuple && tuple->op.same_as(builtin::tvm_tuple()));
+ auto& vinfo = dim_align_[buffer];
+ int dim = tuple->args[0].as<IntImmNode>()->value;
+ if (static_cast<size_t>(dim) >= vinfo.size()) {
+ vinfo.resize(dim + 1);
+ }
+ vinfo[dim].align_factor = tuple->args[1].as<IntImmNode>()->value;
+ vinfo[dim].align_offset = tuple->args[2].as<IntImmNode>()->value;
+
+ return this->VisitStmt(op->body);
+ } else if (op->attr_key == attr::buffer_bind_scope) {
+ Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
+ ICHECK_EQ(arr.size(), 2U);
+ Buffer source = Downcast<Buffer>(arr[0]);
+ Buffer target_with_strides = WithStrides(Downcast<Buffer>(arr[1]));
+ Buffer source_with_strides = WithStrides(source);
+
+ {
+ BufferEntry entry;
+ entry.remap_to = source_with_strides;
+ entry.in_scope = true;
+ entry.is_external = false;
+ buf_map_[source] = entry;
+ }
+
+ Stmt body = this->VisitStmt(op->body);
+
+ return AttrStmt(Array<ObjectRef>{source_with_strides, target_with_strides}, op->attr_key,
+ op->value, body, op->span);
+ } else {
+ return StmtExprMutator::VisitStmt_(op);
+ }
+ }
+
+ Stmt VisitStmt_(const BufferRealizeNode* op) final {
+ Buffer key = op->buffer;
+ Buffer with_strides = WithStrides(op->buffer);
+ {
+ BufferEntry entry;
+ entry.remap_to = with_strides;
+ entry.in_scope = true;
+ entry.is_external = false;
+ buf_map_[key] = entry;
+ }
+
+ Stmt stmt = StmtExprMutator::VisitStmt_(op);
+
+ buf_map_[key].in_scope = false;
+ op = stmt.as<BufferRealizeNode>();
+ ICHECK(op);
+
+ return BufferRealize(with_strides, op->bounds, op->condition, op->body, op->span);
+ }
+
+ PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+ PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+ op = expr.as<BufferLoadNode>();
+
+ auto it = buf_map_.find(op->buffer);
+ ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer;
+ const BufferEntry& e = it->second;
+ ICHECK(e.in_scope) << "Cannot read a buffer that is already out of scope";
+
+ return BufferLoad(e.remap_to, op->indices, op->span);
+ }
+
+ Stmt VisitStmt_(const BufferStoreNode* op) final {
+ Stmt stmt = StmtExprMutator::VisitStmt_(op);
+ op = stmt.as<BufferStoreNode>();
+
+ auto it = buf_map_.find(op->buffer);
+ ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer;
+ const BufferEntry& e = it->second;
+ ICHECK(e.in_scope) << "Cannot write to a buffer that is already out of scope";
+
+ return BufferStore(e.remap_to, op->value, op->indices, op->span);
+ }
+
+ private:
+ Map<Var, Buffer> updated_extern_buffer_map_;
+
+ struct DimAlignInfo {
+ int align_factor{0};
+ int align_offset{0};
+ };
+
+ // Dimension alignment
+ std::unordered_map<Buffer, std::vector<DimAlignInfo>, ObjectPtrHash, ObjectPtrEqual> dim_align_;
+
+ struct BufferEntry {
+ Buffer remap_to;
+ bool in_scope;
+ bool is_external;
+ };
+
+ std::unordered_map<Buffer, BufferEntry, ObjectPtrHash, ObjectPtrEqual> buf_map_;
+
+ IRVisitorWithAnalyzer* bound_analyzer_;
+};
+
+/* Use the scope of IterVar to determine storage scope.
+ *
+ * For buffers that do not have an explicit storage scope defined, a
+ * reasonable storage scope may be defined based on the thread scope
+ * that contains the buffer's allocation. All other buffers without a
+ * scope are assigned to global scope.
+ */
+class ThreadScopePropagate : public StmtExprMutator {
+ public:
+ explicit ThreadScopePropagate(const Map<Var, Buffer>& extern_buffer_map) {
+ // External buffers shouldn't be overwritten, even if they have a
+ // BufferRealizeNode.
+ for (auto kv : extern_buffer_map) {
+ external_buffers_.insert(kv.second);
+ }
+ }
+
+ PrimExpr VisitExpr_(const VarNode* op) final {
+ auto it = buf_remap_.find(GetRef<Var>(op));
+ if (it != buf_remap_.end()) {
+ return it->second->data;
+ } else {
+ return GetRef<PrimExpr>(op);
+ }
+ }
+
+ Stmt VisitStmt_(const AttrStmtNode* op) final {
+ ICHECK_NE(op->attr_key, attr::buffer_dim_align)
+ << "StorageFlattener assumes that all buffers have accurate strides, "
+ << "and all buffer_dim_align annotations are removed. "
+ << "Please run BufferStrideLegalize first.";
+
+ if (op->attr_key == attr::thread_extent) {
+ IterVar iv = Downcast<IterVar>(op->node);
+ ThreadScope ts = ThreadScope::Create(iv->thread_tag);
+ curr_thread_scope_.push_back(ts);
+ Stmt stmt = StmtExprMutator::VisitStmt_(op);
+ curr_thread_scope_.pop_back();
+ return stmt;
+ } else if (op->attr_key == attr::buffer_bind_scope) {
+ return HandleBufferBindScope(op);
+ } else {
+ return StmtExprMutator::VisitStmt_(op);
+ }
+ }
+
+ Stmt VisitStmt_(const BufferRealizeNode* op) final {
+ Var old_var = op->buffer->data;
+
+ // Don't remap buffers that already have an explicit scope,
+ // or external buffers.
+ std::string str_scope = GetPtrStorageScope(old_var);
+ if ((str_scope.length() > 0) || external_buffers_.count(op->buffer)) {
+ return StmtExprMutator::VisitStmt_(op);
+ }
+
+ ICHECK_EQ(buf_remap_.count(old_var), 0)
+ << "Buffer var " << op->buffer->data << " appears in multiple
BufferRealize nodes";
+
+ StorageScope skey;
+ if (curr_thread_scope_.size() == 0) {
+ skey.rank = StorageRank::kGlobal;
+ } else {
+ skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank);
+ }
+
+ auto ptr_type = old_var->type_annotation.as<PointerTypeNode>();
+ ICHECK(ptr_type);
+ Var new_var(old_var->name_hint, PointerType(ptr_type->element_type, skey.to_string()),
+ old_var->span);
+
+ Buffer buf = op->buffer;
+ buf.CopyOnWrite()->data = new_var;
+
+ buf_remap_[old_var] = buf;
+
+ Stmt body = this->VisitStmt(op->body);
+ return BufferRealize(buf, op->bounds, op->condition, body, op->span);
+ }
+
+ PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+ PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+ op = expr.as<BufferLoadNode>();
+ ICHECK(op);
+
+ auto it = buf_remap_.find(op->buffer->data);
+ if (it != buf_remap_.end()) {
+ return BufferLoad(it->second, op->indices, op->span);
+ } else {
+ return expr;
+ }
+ }
+
+ Stmt VisitStmt_(const BufferStoreNode* op) final {
+ Stmt stmt = StmtExprMutator::VisitStmt_(op);
+ op = stmt.as<BufferStoreNode>();
+ ICHECK(op);
+
+ auto it = buf_remap_.find(op->buffer->data);
+ if (it != buf_remap_.end()) {
+ return BufferStore(it->second, op->value, op->indices, op->span);
+ } else {
+ return stmt;
+ }
+ }
+
+ private:
+ // If the rewritten buffers are part of a buffer_bind_scope, either
+ // as the buffer view or as the buffer being viewed, then the
+ // buffer_bind_scope must be rewritten to refer to the updated
+ // buffers.
+ Stmt HandleBufferBindScope(const AttrStmtNode* op) {
+ Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
+ ICHECK_EQ(arr.size(), 2U);
+ Buffer buffer = Downcast<Buffer>(arr[0]);
+ ICHECK(buffer.defined());
+ Buffer target = Downcast<Buffer>(arr[1]);
+ ICHECK(target.defined());
+
+ bool needs_rewrite = false;
+
+ {
+ auto it = buf_remap_.find(buffer->data);
+ if (it != buf_remap_.end()) {
+ needs_rewrite = true;
+ buffer = it->second;
+ }
+ }
+
+ {
+ auto it = buf_remap_.find(target->data);
+ if (it != buf_remap_.end()) {
+ needs_rewrite = true;
+ target = it->second;
+ }
+ }
+
+ if (needs_rewrite) {
+ Stmt body = this->VisitStmt(op->body);
+ return AttrStmt(Array<ObjectRef>{buffer, target}, op->attr_key, op->value, body);
+ } else {
+ return StmtExprMutator::VisitStmt_(op);
+ }
+ }
+
+ std::unordered_map<Var, Buffer, ObjectPtrHash, ObjectPtrEqual> buf_remap_;
+ std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> external_buffers_;
+
+ // The current thread scope.
+ std::vector<ThreadScope> curr_thread_scope_;
+};
+
+/* Map buffer binds to their source buffer
+ *
+ * Buffers defined using an attr::buffer_bind_scope annotation are
+ * views into some linked buffer, potentially into some restricted
+ * subregion of that buffer. This pass identifies such buffers, then
+ * rewrites all accesses of the bound buffers into accesses of the
+ * linked buffer.
+ */
+class BufferBindUnwrapper : public StmtExprMutator {
+ public:
+ explicit BufferBindUnwrapper(const Map<Var, Buffer>& extern_buffer_map,
+ IRVisitorWithAnalyzer* bound_analyzer)
+ : bound_analyzer_(bound_analyzer) {
+ for (auto kv : extern_buffer_map) {
+ BufferEntry e;
+ e.buffer = kv.second;
+ e.external = true;
+ buf_map_[kv.second.get()] = std::move(e);
+ }
+ }
+
+ Stmt VisitStmt_(const StoreNode* op) final {
+ Stmt stmt = StmtExprMutator::VisitStmt_(op);
+ op = stmt.as<StoreNode>();
+ auto it = var_remap_.find(op->buffer_var.get());
+ if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) {
+ // TODO(Lunderberg): Change from warning to error once all mixed
+ // use of physical/logical layouts is removed.
+ DLOG(WARNING) << op->buffer_var << " was declared as buffer (buffer_bind_scope), "
+ << "but is accessed as a pointer (StoreNode).";
+
+ ICHECK(it->second.as<VarNode>());
+ Var new_buf_var = Downcast<Var>(it->second);
+ return Store(new_buf_var, op->value, op->index, op->predicate);
+ } else {
+ return stmt;
+ }
+ }
+
+ PrimExpr VisitExpr_(const LoadNode* op) final {
+ PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+ op = expr.as<LoadNode>();
+ auto it = var_remap_.find(op->buffer_var.get());
+ if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) {
+ // TODO(Lunderberg): Change from warning to error once all mixed
+ // use of physical/logical layouts is removed.
+ DLOG(WARNING) << op->buffer_var << " was declared as buffer (buffer_bind_scope), "
+ << "but is accessed as a pointer (LoadNode).";
+
+ ICHECK(it->second.as<VarNode>());
+ Var new_buf_var = Downcast<Var>(it->second);
+ return Load(op->dtype, new_buf_var, op->index, op->predicate);
+ } else {
+ return expr;
+ }
+ }
+
+ Stmt VisitStmt_(const AttrStmtNode* op) final {
+ ICHECK_NE(op->attr_key, attr::buffer_dim_align)
+ << "BufferBindUnwrapper assumes that all buffers have accurate
strides, "
+ << "and all buffer_dim_align annotations are removed. "
+ << "Please run BufferStrideLegalize first.";
+
+ if (op->attr_key == attr::buffer_bind_scope) {
+ return HandleBufferBindScope(op);
+ } else {
+ return StmtExprMutator::VisitStmt_(op);
+ }
+ }
+
+ PrimExpr VisitExpr_(const VarNode* op) final {
+ auto it = var_remap_.find(op);
+ if (it != var_remap_.end()) {
+ return it->second;
+ } else {
+ return GetRef<PrimExpr>(op);
+ }
+ }
+
+ Array<PrimExpr> remap_indices(Array<PrimExpr> indices, Array<PrimExpr> begins,
+ Array<PrimExpr> extents) {
+ ICHECK_EQ(begins.size(), extents.size());
+
+ if (begins.size() == 0) {
+ return indices;
+ }
+
+ ICHECK_EQ(begins.size(), indices.size());
+
+ Array<PrimExpr> out;
+ for (size_t i = 0; i < begins.size(); i++) {
+ out.push_back(begins[i] + indices[i]);
+ }
+ return out;
+ }
+
+ Array<Range> remap_bounds(Array<Range> bounds, Array<PrimExpr> begins, Array<PrimExpr> extents) {
+ ICHECK_EQ(begins.size(), extents.size());
+
+ if (begins.size() == 0) {
+ return bounds;
+ }
+
+ ICHECK_EQ(begins.size(), bounds.size());
+
+ Array<Range> out;
+ for (size_t i = 0; i < begins.size(); i++) {
+ out.push_back(Range::FromMinExtent(bounds[i]->min + begins[i], bounds[i]->extent));
+ }
+ return out;
+ }
+
+ PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+ PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+ op = expr.as<BufferLoadNode>();
+
+ auto it = buf_map_.find(op->buffer.get());
+ ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer;
+ const BufferEntry& e = it->second;
+ ICHECK(e.in_scope) << "Cannot read from buffer " << op->buffer << ", out of scope.";
+
+ if (e.remap) {
+ return BufferLoad(e.remap->target,
+ remap_indices(op->indices, e.remap->begins, e.remap->extents), op->span);
+ } else {
+ return expr;
+ }
+ }
+
+ Stmt VisitStmt_(const BufferStoreNode* op) final {
+ Stmt stmt = StmtExprMutator::VisitStmt_(op);
+ op = stmt.as<BufferStoreNode>();
+
+ auto it = buf_map_.find(op->buffer.get());
+ ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer;
+ const BufferEntry& e = it->second;
+ ICHECK(e.in_scope) << "Cannot write to buffer " << op->buffer << ", out of scope.";
+
+ if (e.remap) {
+ return BufferStore(e.remap->target, op->value,
+ remap_indices(op->indices, e.remap->begins, e.remap->extents), op->span);
+ } else {
+ return stmt;
+ }
+ }
+
+ Stmt VisitStmt_(const BufferRealizeNode* op) final {
+ const auto& key = op->buffer.get();
+
+ bool is_external = false;
+
+ if (buf_map_.count(key)) {
+ ICHECK(buf_map_.at(key).external)
+ << "BufferRealize node for internal buffer " << op->buffer << "
occurred multiple times.";
+
+ is_external = true;
+ } else {
+ BufferEntry e;
+ e.bounds = op->bounds;
+ e.buffer = op->buffer;
+ buf_map_[key] = std::move(e);
+ }
+
+ Stmt stmt = StmtExprMutator::VisitStmt_(op);
+
+ if (is_external) {
+ buf_map_[key].in_scope = false;
+ }
+
+ return stmt;
+ }
+
+ Stmt VisitStmt_(const PrefetchNode* op) final {
+ Stmt stmt = StmtExprMutator::VisitStmt_(op);
+ op = stmt.as<PrefetchNode>();
+ ICHECK(op != nullptr);
+
+ const auto& key = op->buffer.get();
+ auto it = buf_map_.find(key);
+ ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key;
+ const BufferEntry& e = it->second;
+
+ ICHECK(e.in_scope) << "Read a buffer that is already out of scope";
+ ICHECK_EQ(e.buffer->shape.size(), op->bounds.size())
+ << "Prefetch dim should be the same as buffer dim";
+
+ if (e.remap) {
+ return Prefetch(e.remap->target, remap_bounds(op->bounds, e.remap->begins, e.remap->extents),
+ op->span);
+ } else {
+ return stmt;
+ }
+ }
+
+ private:
+ // Read the mapping from a buffer view to the actual buffer. This
+ // allows all later BufferStore/BufferLoad nodes to reference the
+ // actual buffer, rather than the buffer view.
+ Stmt HandleBufferBindScope(const AttrStmtNode* op) {
+ // Unpack information from Attribute node
+ RemapInfo remap;
+
+ Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
+ ICHECK_EQ(arr.size(), 2U);
+ const Buffer source = Downcast<Buffer>(arr[0]);
+ ICHECK(source.defined());
+ remap.target = Downcast<Buffer>(arr[1]);
+ ICHECK(remap.target.defined());
+ const CallNode* tuple = op->value.as<CallNode>();
+ ICHECK(tuple && tuple->op.same_as(builtin::tvm_tuple()));
+
+ for (size_t i = 0; i < tuple->args.size(); i += 2) {
+ remap.begins.push_back(tuple->args[i]);
+ remap.extents.push_back(tuple->args[i + 1]);
+ }
+
+ // Determine bounds in the target buffer
+ auto it = buf_map_.find(remap.target.get());
+ ICHECK(it != buf_map_.end()) << "Cannot define " << source << " as a view into " << remap.target
+ << ", " << remap.target << " was not defined.";
+ const BufferEntry& target_info = it->second;
+ ICHECK(target_info.in_scope) << "Cannot define " << source << " as a view into " << remap.target
+ << ", " << remap.target << " is out of scope.";
+ ICHECK_EQ(remap.begins.size(), target_info.buffer->shape.size())
+ << "Incorrect number of arguments in buffer_bind_scope attribute. "
+ << "Expected (min_0, extent_0, min_1, extent_0, ..., min_N,
extent_N).";
+
+ if (target_info.bounds.size() > 0) {
+ Array<PrimExpr> mapped_begins;
+ for (size_t i = 0; i < target_info.buffer->shape.size(); ++i) {
+ mapped_begins.push_back(remap.begins[i] - target_info.bounds[i]->min);
+ }
+ remap.begins = std::move(mapped_begins);
+ }
+
+ ICHECK(target_info.remap == nullptr)
+ << "buffer_bind_scope defines " << source << " as a view into " <<
remap.target
+ << ", which is itself a buffer view. "
+ << "Indirect remapping not currently supported.";
+
+ for (size_t i = 0; i < remap.begins.size(); i++) {
+ remap.begins.Set(i, bound_analyzer_->Simplify(remap.begins[i]));
+ remap.extents.Set(i, bound_analyzer_->Simplify(remap.extents[i]));
+ }
+
+ // Add a buffer remap entry
+ {
+ BufferEntry source_info;
+ source_info.buffer = source;
+ source_info.remap = std::make_unique<RemapInfo>(remap);
+
+ buf_map_[source.get()] = std::move(source_info);
+ }
+
+ // Define remappings of any remaining Variables (e.g. Store/Load nodes).
+ ArgBinder binder(&var_remap_);
+
+ // Define a view that represents the source's view into the target
+ // buffer. This Buffer object is only used to define the mapping
+ // to the target buffer, and never actually appears in the TIR
+ // graph.
+ Buffer view = remap.target.MakeSlice(remap.begins, remap.extents);
+ if (source->strides.size() == 0) {
+ ICHECK_EQ(view->strides.size(), 0U)
+ << "Cannot bind a compact buffer " << source << " to a strided
buffer " << view
+ << " with strides " << view->strides;
+ } else {
+ // Add explicit strides to the view, in order to bind to source.strides[i].
+ view = view.MakeStrideView();
+ }
+
+ // Bind any variables that reference the view (e.g. elem_offset,
+ // strides, shape). Pass fuzzy_match=false, because all shape
+ // transformations should have been handled in
+ // BufferShapeLegalize.
+ binder.BindBuffer(source, view, source->name, false);
+
+ // Apply the remaps
+ Stmt body = op->body;
+ body = MergeNest(binder.asserts(), body);
+ body = MergeNest(binder.init_nest(), body);
+ body = this->VisitStmt(body);
+ // remove the binds
+ for (const Var& v : binder.defs()) {
+ var_remap_.erase(v.get());
+ }
+ return body;
+ }
+
+ struct RemapInfo {
+ Buffer target;
+ Array<PrimExpr> begins;
+ Array<PrimExpr> extents;
+ };
+
+ // The buffer entry in the flatten map
+ struct BufferEntry {
+ // The storage buffer
+ Buffer buffer;
+ // the bounds of realization, can be null, means everything
+ Region bounds;
+ // Whether the buffer is external
+ bool external{false};
+ // Whether we are within the allocation scope of the buffer.
+ bool in_scope{true};
+
+ // The buffer to which the storage buffer should be remapped.
+ std::unique_ptr<RemapInfo> remap{nullptr};
+ };
+
+ // The buffer assignment map
+ // Variable remap
+ std::unordered_map<const VarNode*, PrimExpr> var_remap_;
+ // Buffer map
+ std::unordered_map<const BufferNode*, BufferEntry> buf_map_;
+ // Analyzer for the variable bounds, used to simplify the bounds populator. We really need the
+ // analyzer from it.
+ IRVisitorWithAnalyzer* bound_analyzer_;
+};
+
class StorageFlattener : public StmtExprMutator {
public:
explicit StorageFlattener(const Map<Var, Buffer>& extern_buffer_map, int cache_line_size,
@@ -78,6 +1056,15 @@ class StorageFlattener : public StmtExprMutator {
}
Stmt VisitStmt_(const AttrStmtNode* op) final {
+ ICHECK_NE(op->attr_key, attr::buffer_dim_align)
+ << "StorageFlattener assumes that all buffers have accurate strides, "
+ << "and all buffer_dim_align annotations are removed. "
+ << "Please run BufferStrideLegalize first.";
+
+ ICHECK_NE(op->attr_key, attr::buffer_bind_scope)
+ << "StorageFlattener assumes that all buffer binds have already been
applied. "
+ << "Please run BufferBindUnwrapper first.";
+
if (op->attr_key == attr::double_buffer_scope && op->node->IsInstance<tir::BufferNode>()) {
auto buffer = Downcast<tir::Buffer>(op->node);
Stmt body = this->VisitStmt(op->body);
@@ -85,27 +1072,6 @@ class StorageFlattener : public StmtExprMutator {
ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << buffer;
body = AttrStmt(it->second.buffer->data, op->attr_key, op->value, std::move(body));
return body;
- } else if (op->attr_key == attr::thread_extent) {
- IterVar iv = Downcast<IterVar>(op->node);
- ThreadScope ts = ThreadScope::Create(iv->thread_tag);
- curr_thread_scope_.push_back(ts);
- Stmt stmt = StmtExprMutator::VisitStmt_(op);
- curr_thread_scope_.pop_back();
- return stmt;
- } else if (op->attr_key == attr::buffer_bind_scope) {
- return HandleBufferBindScope(op);
- } else if (op->attr_key == attr::buffer_dim_align) {
- auto buffer = Downcast<tir::Buffer>(op->node);
- const CallNode* tuple = op->value.as<CallNode>();
- ICHECK(tuple && tuple->op.same_as(builtin::tvm_tuple()));
- auto& vinfo = dim_align_[buffer];
- int dim = tuple->args[0].as<IntImmNode>()->value;
- if (static_cast<size_t>(dim) >= vinfo.size()) {
- vinfo.resize(dim + 1);
- }
- vinfo[dim].align_factor = tuple->args[1].as<IntImmNode>()->value;
- vinfo[dim].align_offset = tuple->args[2].as<IntImmNode>()->value;
- return this->VisitStmt(op->body);
}
return StmtExprMutator::VisitStmt_(op);
}
@@ -121,9 +1087,9 @@ class StorageFlattener : public StmtExprMutator {
ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key;
const BufferEntry& e = it->second;
- ICHECK(!e.released) << "Read a buffer that is already out of scope";
+ ICHECK(e.in_scope) << "Cannot write to " << op->buffer << ", out of scope.";
- Stmt body = e.buffer.vstore(e.RelIndex(op->indices), op->value);
+ Stmt body = e.buffer.vstore(op->indices, op->value);
if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
shape_collector_.push_back(std::make_pair(e.buffer->data, e.buffer->shape));
}
@@ -141,26 +1107,31 @@ class StorageFlattener : public StmtExprMutator {
const auto& key = op->buffer;
if (buf_map_.count(key)) {
- ICHECK(buf_map_.at(key).external);
+ ICHECK(buf_map_.at(key).external)
+ << "BufferRealize for internal buffer " << op->buffer << " appears
multiple times.";
return this->VisitStmt(op->body);
} else {
// create a buffer entry
BufferEntry e;
- e.bounds = op->bounds;
- Array<PrimExpr> shape;
- for (auto r : e.bounds) {
- shape.push_back(r->extent);
- }
- // deduce current storage scope.
- StorageScope skey;
- std::string strkey = GetPtrStorageScope(op->buffer->data);
- if (strkey.length() == 0) {
- if (curr_thread_scope_.size() != 0) {
- skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank);
- }
- } else {
- skey = StorageScope::Create(strkey);
+
+ ICHECK_EQ(op->buffer->shape.size(), op->bounds.size())
+ << "Inconsistent buffer shape and realization shape for " <<
op->buffer;
+
+ for (size_t i = 0; i < op->bounds.size(); i++) {
+ const auto& bound = op->bounds[i];
+ const auto& dim_size = op->buffer->shape[i];
+ ICHECK(is_zero(bound_analyzer_->Simplify(bound->min)))
+ << "Buffer " << op->buffer << " has realization bounds that do not
start at zero. "
+ << "Please run BufferShapeLegalize first.";
+ ICHECK(is_one(bound_analyzer_->Simplify(bound->extent == dim_size)))
+ << "Buffer " << op->buffer
+ << " has realization extent that does not match its size. "
+ "Please run BufferShapeLegalize first.";
}
+
+ Array<PrimExpr> shape = op->buffer->shape;
+ StorageScope skey = StorageScope::Create(GetPtrStorageScope(op->buffer->data));
+
// use small alignment for small arrays
auto dtype = op->buffer->dtype;
int32_t const_size = AllocateNode::constant_allocation_size(shape);
@@ -173,36 +1144,14 @@ class StorageFlattener : public StmtExprMutator {
<< "Allocation exceed bound of memory tag " << skey.to_string();
}
}
- Array<PrimExpr> strides;
- if (dim_align_.count(key) != 0 && shape.size() != 0) {
- std::vector<PrimExpr> rstrides;
- const std::vector<DimAlignInfo>& avec = dim_align_[key];
- int first_dim = 0;
- PrimExpr stride = make_const(shape[first_dim].dtype(), 1);
- for (size_t i = shape.size(); i != 0; --i) {
- size_t dim = i - 1;
- if (dim < avec.size() && avec[dim].align_factor != 0) {
- PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor);
- PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset);
- stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor);
- stride = bound_analyzer_->Simplify(stride);
- }
- rstrides.push_back(stride);
- stride = stride * shape[dim];
- }
- strides = Array<PrimExpr>(rstrides.rbegin(), rstrides.rend());
- }
+ Array<PrimExpr> strides = op->buffer->strides;
- auto* ptr_type = op->buffer->data->type_annotation.as<PointerTypeNode>();
- ICHECK(ptr_type);
- auto new_var =
- Var(op->buffer->data->name_hint, PointerType(ptr_type->element_type, skey.to_string()));
- e.buffer = Buffer(new_var, op->buffer->dtype, shape, strides, PrimExpr(), op->buffer->name,
- align, 0, kDefault);
+ e.buffer = Buffer(op->buffer->data, op->buffer->dtype, shape, strides, PrimExpr(),
+ op->buffer->name, align, 0, kDefault);
buf_map_[key] = e;
Stmt body = this->VisitStmt(op->body);
- buf_map_[key].released = true;
+ buf_map_[key].in_scope = false;
Stmt ret;
DataType storage_type = e.buffer->dtype;
@@ -264,12 +1213,12 @@ class StorageFlattener : public StmtExprMutator {
auto it = buf_map_.find(key);
ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key;
const BufferEntry& e = it->second;
- ICHECK(!e.released) << "Read a buffer that is already out of scope";
+ ICHECK(e.in_scope) << "Cannot read from " << op->buffer << ", out of scope.";
if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
shape_collector_.push_back(std::make_pair(e.buffer->data, e.buffer->shape));
}
- return e.buffer.vload(e.RelIndex(op->indices), e.buffer->dtype);
+ return e.buffer.vload(op->indices, e.buffer->dtype);
}
Stmt VisitStmt_(const PrefetchNode* op) final {
@@ -282,7 +1231,7 @@ class StorageFlattener : public StmtExprMutator {
ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key;
const BufferEntry& e = it->second;
- ICHECK(!e.released) << "Read a buffer that is already out of scope";
+ ICHECK(e.in_scope) << "Cannot prefetch " << op->buffer << ", out of scope.";
ICHECK_EQ(e.buffer->shape.size(), op->bounds.size())
<< "Prefetch dim should be the same as buffer dim";
@@ -315,7 +1264,7 @@ class StorageFlattener : public StmtExprMutator {
if (i < starts) {
stmt = For(vars[i], 0, op->bounds[i]->extent, ForKind::kSerial, stmt);
} else {
- PrimExpr load = e.buffer.vload(e.RelIndex(args), e.buffer->dtype);
+ PrimExpr load = e.buffer.vload(args, e.buffer->dtype);
PrimExpr address = Call(DataType::Handle(), builtin::address_of(), {load});
PrimExpr prefetch = Call(op->buffer->dtype, builtin::prefetch(), {address, 0, 3, 1});
stmt = Evaluate(prefetch);
@@ -327,107 +1276,24 @@ class StorageFlattener : public StmtExprMutator {
}
PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
- LOG(FATAL) << "ProducerLoad cannot appear in a valid TIR PrimFunc.";
+ LOG(FATAL) << "ProducerLoad cannot appear in a valid TIR PrimFunc. "
+ << "Please run SchedulePostProcToPrimFunc first.";
return PrimExpr();
}
Stmt VisitStmt_(const ProducerStoreNode* op) final {
- LOG(FATAL) << "Cannot handle Provide "
- << " please run SchedulePostProcToPrimFunc first";
+ LOG(FATAL) << "ProducerStore cannot appear in a valid TIR PrimFunc. "
+ << "Please run SchedulePostProcToPrimFunc first.";
return Stmt();
}
Stmt VisitStmt_(const ProducerRealizeNode* op) final {
- LOG(FATAL) << "Cannot handle Realize "
- << " please run SchedulePostProcToPrimFunc first";
+ LOG(FATAL) << "ProducerRealize cannot appear in a valid TIR PrimFunc. "
+ << "Please run SchedulePostProcToPrimFunc first.";
return Stmt();
}
private:
- // The specific tensor data layout is not determined before
- // StorageFlatten pass. We use buffer_bind_scope
- // to specify before hand we want to bind a subregion
- // of tensor to a symbolic buffer, which get used in extern.
- //
- // Example:
- //
- // realize A in range [i*4, extent=10) {
- // bind Ab to A in [i*4+1, extent=4) {
- // call_func(Ab.ptr, Ab.shape[0])
- // }
- // }
- //
- // After StorageFlatten
- //
- // alloc A[10]
- // call(A + 1, 4)
- //
- // Buffer is a protocol to declare specific
- // data layout and shape we expect.
- // So this function need to check:
- // - If the bind range is within the realize range
- // - If we can match the requirement of buffer
- // - Remap variables such as Ab.ptr to the actual value.
- //
- // Here are a few possible failure cases:
- // - Buffer is declared to have constant shape,
- // but we try to bind it to a different one.
- // - Buffer is declared to be compact(no strides)
- // but this binded region is a subregion of
- // a matrix(tensor), which means it requires strides.
- //
- // We do support a few relaxed case, such as bindingx
- // region with shape [1, 1, n, m] to buffer with shape [n, m]
- Stmt HandleBufferBindScope(const AttrStmtNode* op) {
- Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
- ICHECK_EQ(arr.size(), 2U);
- const BufferNode* buffer = arr[0].as<BufferNode>();
- const BufferNode* target = arr[1].as<BufferNode>();
- const CallNode* tuple = op->value.as<CallNode>();
- ICHECK(buffer && target);
- ICHECK(tuple && tuple->op.same_as(builtin::tvm_tuple()));
- auto key = GetRef<Buffer>(target);
-
- auto it = buf_map_.find(key);
- ICHECK(it != buf_map_.end()) << "Cannot find buffer of " << key;
- const BufferEntry& be = it->second;
- ICHECK(!be.released);
- ICHECK_EQ(tuple->args.size(), be.buffer->shape.size() * 2);
- Array<PrimExpr> begins, extents;
- if (be.bounds.size() != 0) {
- ICHECK_EQ(tuple->args.size(), be.bounds.size() * 2);
- for (size_t i = 0; i < be.buffer->shape.size(); ++i) {
- begins.push_back(tuple->args[2 * i] - be.bounds[i]->min);
- extents.push_back(tuple->args[2 * i + 1]);
- }
- } else {
- for (size_t i = 0; i < tuple->args.size(); i += 2) {
- begins.push_back(tuple->args[i]);
- auto new_extent = bound_analyzer_->Simplify(tuple->args[i + 1]);
- extents.push_back(new_extent);
- }
- }
- Buffer slice = be.buffer.MakeSlice(begins, extents);
- if (buffer->strides.size() == 0) {
- ICHECK_EQ(slice->strides.size(), 0U)
- << "Trying to bind compact buffer to strided one strides=" <<
slice->strides;
- } else {
- slice = slice.MakeStrideView();
- }
- // start binding
- ArgBinder binder(&var_remap_);
- binder.BindBuffer(Downcast<Buffer>(arr[0]), slice, buffer->name, true);
- // Apply the remaps
- Stmt body = MergeNest(binder.asserts(), op->body);
- body = MergeNest(binder.init_nest(), body);
- body = this->VisitStmt(body);
- // remove the binds
- for (const Var& v : binder.defs()) {
- var_remap_.erase(v.get());
- }
- return body;
- }
-
// The buffer entry in the flatten map
struct DimAlignInfo {
int align_factor{0};
@@ -437,25 +1303,10 @@ class StorageFlattener : public StmtExprMutator {
struct BufferEntry {
// the buffer of storage
Buffer buffer;
- // the bounds of realization, can be null, means everything
- Region bounds;
// Whether the buffer is external
bool external{false};
- // Whether we are out of allocation bounds and buffer get released.
- bool released{false};
- // relative index
- inline Array<PrimExpr> RelIndex(Array<PrimExpr> args) const {
- if (bounds.size() != 0) {
- Array<PrimExpr> index;
- ICHECK_EQ(bounds.size(), args.size());
- for (size_t i = 0; i < bounds.size(); ++i) {
- index.push_back(args[i] - bounds[i]->min);
- }
- return index;
- } else {
- return args;
- }
- }
+ // Whether the buffer is currently in scope.
+ bool in_scope{true};
};
bool ShapeIsValid(const Array<PrimExpr>& shape) {
@@ -484,10 +1335,6 @@ class StorageFlattener : public StmtExprMutator {
std::unordered_map<const VarNode*, PrimExpr> var_remap_;
// Buffer map
std::unordered_map<Buffer, BufferEntry, ObjectPtrHash, ObjectPtrEqual> buf_map_;
- // Dimension alignment
- std::unordered_map<Buffer, std::vector<DimAlignInfo>, ObjectPtrHash, ObjectPtrEqual> dim_align_;
- // The current thread scope.
- std::vector<ThreadScope> curr_thread_scope_;
// Collects shapes.
std::vector<std::pair<Var, Array<PrimExpr>>> shape_collector_;
// bounds populator. We really need the analyzer from it.
@@ -499,6 +1346,68 @@ class StorageFlattener : public StmtExprMutator {
bool create_bound_attributes_{false};
};
+/*!
+ * \brief Simplify assert statements.
+ *
+ * If an assert statement can be statically verified to be true,
+ * remove the assert statement. Otherwise, keep the assert statement
+ * unmodified.
+ */
+class AssertSimplifier : public StmtMutator {
+ public:
+ explicit AssertSimplifier(IRVisitorWithAnalyzer* bound_analyzer)
+ : bound_analyzer_(bound_analyzer) {}
+
+ Stmt VisitStmt_(const AssertStmtNode* op) final {
+ Stmt stmt = StmtMutator::VisitStmt_(op);
+ op = stmt.as<AssertStmtNode>();
+
+ PrimExpr condition = bound_analyzer_->Simplify(op->condition);
+ if (is_one(condition)) {
+ return op->body;
+ }
+
+ return stmt;
+ }
+
+ private:
+ IRVisitorWithAnalyzer* bound_analyzer_;
+};
+
+// The specific tensor data layout is not determined before
+// StorageFlatten pass. We use buffer_bind_scope
+// to specify before hand we want to bind a subregion
+// of tensor to a symbolic buffer, which get used in extern.
+//
+// Example:
+//
+// realize A in range [i*4, extent=10) {
+// bind Ab to A in [i*4+1, extent=4) {
+// call_func(Ab.ptr, Ab.shape[0])
+// }
+// }
+//
+// After StorageFlatten
+//
+// alloc A[10]
+// call(A + 1, 4)
+//
+// Buffer is a protocol to declare specific
+// data layout and shape we expect.
+// So this function needs to check:
+// - If the bind range is within the realize range
+// - If we can match the requirement of buffer
+// - Remap variables such as Ab.ptr to the actual value.
+//
+// Here are a few possible failure cases:
+// - Buffer is declared to have constant shape,
+// but we try to bind it to a different one.
+// - Buffer is declared to be compact (no strides),
+// but the bound region is a subregion of
+// a matrix (tensor), which means it requires strides.
+//
+// We do support a few relaxed cases, such as binding a
+// region with shape [1, 1, n, m] to buffer with shape [n, m]
PrimFunc StorageFlatten(PrimFunc func, int cache_line_size, bool create_bound_attributes) {
// Only apply this pass to TIR from TE schedules
Optional<Bool> from_legacy_te_schedule = func->GetAttr("from_legacy_te_schedule", Bool(false));
@@ -507,8 +1416,22 @@ PrimFunc StorageFlatten(PrimFunc func, int cache_line_size, bool create_bound_at
IRVisitorWithAnalyzer bound_analyzer;
bound_analyzer(fptr->body);
+
+ fptr->body = BufferShapeLegalize(fptr->buffer_map, &bound_analyzer)(std::move(fptr->body));
+
+ auto stride_legalize = BufferStrideLegalize(fptr->buffer_map, &bound_analyzer);
+ fptr->body = stride_legalize(std::move(fptr->body));
+ fptr->buffer_map = stride_legalize.UpdatedExternBufferMap();
+
+ fptr->body = ThreadScopePropagate(fptr->buffer_map)(std::move(fptr->body));
+
+ fptr->body = BufferBindUnwrapper(fptr->buffer_map, &bound_analyzer)(std::move(fptr->body));
+
fptr->body = StorageFlattener(fptr->buffer_map, cache_line_size, create_bound_attributes,
&bound_analyzer)(std::move(fptr->body));
+
+ fptr->body = AssertSimplifier(&bound_analyzer)(std::move(fptr->body));
+
return func;
} else {
return func;