Lunderberg commented on code in PR #14021:
URL: https://github.com/apache/tvm/pull/14021#discussion_r1146293911
##########
include/tvm/tir/transform.h:
##########
@@ -437,10 +437,11 @@ TVM_DLL Pass ConvertBlocksToOpaque();
*
* \endcode
*
- *
+ * \param is_strict ensure the compacted shape always smaller than the original shape.
+ * otherwise it allows to grow the shape to match actual accessed buffer regions.
* \return The pass.
*/
-TVM_DLL Pass CompactBufferAllocation();
+TVM_DLL Pass CompactBufferAllocation(bool is_strict = true);
Review Comment:
Should the `is_strict` argument be a parameter in the `PassContext` instead?
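A self-contained sketch of what the suggestion amounts to: the pass reads `is_strict` from an ambient, string-keyed configuration with a default of `true`, instead of taking it as a constructor argument. `MockPassContext`, `GetBool`, the key string, and `CompactShape` are all hypothetical stand-ins, not TVM's `PassContext` API. The per-dimension clamp in strict mode is one reading of the `\param` doc comment above, not the PR's implementation.

```cpp
#include <algorithm>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for a pass context: a string-keyed table of options.
// This is NOT TVM's PassContext API; it only illustrates the lookup pattern.
struct MockPassContext {
  std::unordered_map<std::string, bool> config;

  // Return the configured value for `key`, or `default_value` if it is unset.
  bool GetBool(const std::string& key, bool default_value) const {
    auto it = config.find(key);
    return it == config.end() ? default_value : it->second;
  }
};

// Hypothetical option key, chosen for illustration only.
constexpr const char kIsStrictKey[] = "tir.CompactBufferAllocation.is_strict";

// Assumed semantics, per the doc comment: in strict mode the compacted extent
// never exceeds the original extent; otherwise it may grow to the accessed extent.
std::vector<int64_t> CompactShape(const std::vector<int64_t>& original,
                                  const std::vector<int64_t>& accessed,
                                  const MockPassContext& ctx) {
  bool is_strict = ctx.GetBool(kIsStrictKey, /*default_value=*/true);
  std::vector<int64_t> result(original.size());
  for (size_t i = 0; i < original.size(); ++i) {
    result[i] = is_strict ? std::min(original[i], accessed[i]) : accessed[i];
  }
  return result;
}
```

One trade-off to weigh: a context option applies to every run of the pass inside that context, while a constructor argument lets each pipeline choose per call site.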
##########
src/tir/transforms/compact_buffer_region.cc:
##########
@@ -469,7 +507,8 @@ class BufferCompactor : public StmtExprMutator {
// Step 0. Check there is no Init part.
ICHECK(!op->init.defined());
// Step 1. Reallocate and rewrite alloc_buffers, also update BufferAllocInfo.
- Array<Buffer> alloc_buffers = RewriteAllocBuffer(op->alloc_buffers);
+ Array<Buffer> alloc_buffers = op->alloc_buffers.Map(
Review Comment:
Nit: Personal preference, but I think lambda expressions tend to be more
readable than the equivalent `std::bind` expression. Here, this could be
rewritten as
`op->alloc_buffers.Map([this](const Buffer& buf) { return RewriteAllocBuffer(buf); });`
##########
src/tir/transforms/compact_buffer_region.cc:
##########
@@ -481,47 +520,67 @@ class BufferCompactor : public StmtExprMutator {
return std::move(block);
}
- Array<Buffer> RewriteAllocBuffer(const Array<Buffer>& buffers) {
- Array<Buffer> result;
- result.reserve(buffers.size());
- for (const Buffer& buffer : buffers) {
- auto it = buffer_info_.find(buffer);
- ICHECK(it != buffer_info_.end());
- BufferAllocInfo& info = it->second;
- Array<PrimExpr> shape;
- shape.reserve(info.region.size());
- for (const Range& range : info.region) {
- shape.push_back(range->extent);
- }
- Array<PrimExpr> strides;
- if (info.dim_aligns.size()) {
- ICHECK(info.dim_aligns.size() == shape.size());
- strides.resize(shape.size());
- PrimExpr stride = make_const(shape[0].dtype(), 1);
- for (size_t i = shape.size(); i != 0; --i) {
- size_t dim = i - 1;
- if (info.dim_aligns[dim].align_factor != 0) {
- PrimExpr factor = make_const(stride.dtype(), info.dim_aligns[dim].align_factor);
- PrimExpr offset = make_const(stride.dtype(), info.dim_aligns[dim].align_offset);
- stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor);
- }
- strides.Set(dim, stride);
- stride = stride * shape[dim];
+ Stmt VisitStmt_(const DeclBufferNode* op) final {
+ Buffer new_buffer = RewriteAllocBuffer(op->buffer);
+ auto n = CopyOnWrite(op);
+ n->buffer = std::move(new_buffer);
+ n->body = VisitStmt(op->body);
+ return DeclBuffer(n);
+ }
+
+ Stmt VisitStmt_(const AllocateNode* op) final {
Review Comment:
Good point. The potential issue I was picturing was an
`AllocateConst`/`DeclBuffer` pair, where only the `DeclBuffer` would be
rewritten. Now that the `Allocate` is used to identify locations that should
be rewritten, rather than the `DeclBuffer`, this case can no longer occur.