Lunderberg commented on code in PR #14021:
URL: https://github.com/apache/tvm/pull/14021#discussion_r1146293911


##########
include/tvm/tir/transform.h:
##########
@@ -437,10 +437,11 @@ TVM_DLL Pass ConvertBlocksToOpaque();
  *
  *  \endcode
  *
- *
+ * \param is_strict ensure the compacted shape always smaller than the original shape.
+ *   otherwise it allows to grow the shape to match actual accessed buffer regions.
  * \return The pass.
  */
-TVM_DLL Pass CompactBufferAllocation();
+TVM_DLL Pass CompactBufferAllocation(bool is_strict = true);

Review Comment:
   Should the `is_strict` argument be a parameter in the `PassContext` instead?
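To make the suggestion concrete, here is a minimal standalone sketch of the two ideas involved. It reads a flag from a context-wide config with a fallback default instead of taking it as a function argument, and it applies one assumed reading of `is_strict`: the compacted extent is clamped to the original extent when strict and may grow otherwise. `MockPassContext`, `CompactShape`, and the option key `tir.compact_buffer.is_strict` are hypothetical. This is not TVM's `PassContext` API or the pass's actual implementation.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for a pass context: a string-keyed bag of boolean options.
// This is NOT TVM's PassContext API; it only illustrates the design choice.
struct MockPassContext {
  std::map<std::string, bool> config;

  // Return the configured value, or `default_value` when the key is absent.
  bool GetBool(const std::string& key, bool default_value) const {
    auto it = config.find(key);
    return it == config.end() ? default_value : it->second;
  }
};

// Assumed semantics of `is_strict`: when true, a compacted extent never exceeds
// the original extent; when false, the accessed-region extent is used as-is.
std::vector<int64_t> CompactShape(const std::vector<int64_t>& original,
                                  const std::vector<int64_t>& accessed,
                                  const MockPassContext& ctx) {
  assert(original.size() == accessed.size());
  // Hypothetical option key; defaults to strict, matching `is_strict = true`.
  bool is_strict = ctx.GetBool("tir.compact_buffer.is_strict", /*default_value=*/true);
  std::vector<int64_t> result(original.size());
  for (std::size_t i = 0; i < original.size(); ++i) {
    result[i] = is_strict ? std::min(original[i], accessed[i]) : accessed[i];
  }
  return result;
}
```

With an empty context the flag falls back to `true`, so an accessed extent of 20 along a dimension originally sized 16 stays at 16. Setting the option to `false` lets it grow to 20. The trade-off: a context option leaves the pass constructor argument-free and can be toggled for a whole pipeline, whereas a function parameter makes the choice explicit at each call site.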



##########
src/tir/transforms/compact_buffer_region.cc:
##########
@@ -469,7 +507,8 @@ class BufferCompactor : public StmtExprMutator {
     // Step 0. Check there is no Init part.
     ICHECK(!op->init.defined());
     // Step 1. Reallocate and rewrite alloc_buffers, also update BufferAllocInfo.
-    Array<Buffer> alloc_buffers = RewriteAllocBuffer(op->alloc_buffers);
+    Array<Buffer> alloc_buffers = op->alloc_buffers.Map(

Review Comment:
   Nit: Personal preference, but I think lambda expressions tend to be more readable than the equivalent `std::bind` expression. Here, this could be rewritten as `op->alloc_buffers.Map([this](const Buffer& buf) { return RewriteAllocBuffer(buf); });`
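A standalone illustration of the nit, outside TVM: both member functions below map a per-element method over a container, once through `std::bind` with a placeholder and once through a capturing lambda. `Rewriter` and `Rewrite` are hypothetical names, and `std::vector` with `std::transform` stands in for TVM's `Array::Map`.

```cpp
#include <algorithm>
#include <functional>
#include <iterator>
#include <string>
#include <vector>

// Hypothetical mutator-like class with a per-element rewrite method.
class Rewriter {
 public:
  std::string Rewrite(const std::string& name) const { return name + "_compact"; }

  // Map via std::bind: `_1` forwards each element to Rewrite on `this`.
  std::vector<std::string> MapWithBind(const std::vector<std::string>& in) const {
    std::vector<std::string> out;
    out.reserve(in.size());
    std::transform(in.begin(), in.end(), std::back_inserter(out),
                   std::bind(&Rewriter::Rewrite, this, std::placeholders::_1));
    return out;
  }

  // Equivalent lambda: the capture, argument type, and call are all explicit.
  std::vector<std::string> MapWithLambda(const std::vector<std::string>& in) const {
    std::vector<std::string> out;
    out.reserve(in.size());
    std::transform(in.begin(), in.end(), std::back_inserter(out),
                   [this](const std::string& s) { return Rewrite(s); });
    return out;
  }
};
```

Both produce the same result. The lambda spells out the argument type and the call, whereas `std::bind` hides them behind `std::placeholders::_1`.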



##########
src/tir/transforms/compact_buffer_region.cc:
##########
@@ -481,47 +520,67 @@ class BufferCompactor : public StmtExprMutator {
     return std::move(block);
   }
 
-  Array<Buffer> RewriteAllocBuffer(const Array<Buffer>& buffers) {
-    Array<Buffer> result;
-    result.reserve(buffers.size());
-    for (const Buffer& buffer : buffers) {
-      auto it = buffer_info_.find(buffer);
-      ICHECK(it != buffer_info_.end());
-      BufferAllocInfo& info = it->second;
-      Array<PrimExpr> shape;
-      shape.reserve(info.region.size());
-      for (const Range& range : info.region) {
-        shape.push_back(range->extent);
-      }
-      Array<PrimExpr> strides;
-      if (info.dim_aligns.size()) {
-        ICHECK(info.dim_aligns.size() == shape.size());
-        strides.resize(shape.size());
-        PrimExpr stride = make_const(shape[0].dtype(), 1);
-        for (size_t i = shape.size(); i != 0; --i) {
-          size_t dim = i - 1;
-          if (info.dim_aligns[dim].align_factor != 0) {
-            PrimExpr factor = make_const(stride.dtype(), info.dim_aligns[dim].align_factor);
-            PrimExpr offset = make_const(stride.dtype(), info.dim_aligns[dim].align_offset);
-            stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor);
-          }
-          strides.Set(dim, stride);
-          stride = stride * shape[dim];
+  Stmt VisitStmt_(const DeclBufferNode* op) final {
+    Buffer new_buffer = RewriteAllocBuffer(op->buffer);
+    auto n = CopyOnWrite(op);
+    n->buffer = std::move(new_buffer);
+    n->body = VisitStmt(op->body);
+    return DeclBuffer(n);
+  }
+
+  Stmt VisitStmt_(const AllocateNode* op) final {

Review Comment:
   Good point. The potential issue I was picturing was an `AllocateConst`/`DeclBuffer` pair, where only the `DeclBuffer` would be rewritten. Now that the `Allocate` is used to identify locations that should be rewritten, rather than the `DeclBuffer`, this case can no longer occur.
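For reference, the alignment-aware stride computation in the removed `RewriteAllocBuffer` loop can be sketched with plain integers. Walking from the innermost dimension outward, each aligned dimension's stride is bumped by the smallest non-negative amount that makes `stride % align_factor == align_offset`. `DimAlign` and `AlignedStrides` are hypothetical names. The sketch assumes `0 <= align_offset < align_factor`, and the real code builds `PrimExpr` values with `make_const`/`indexmod` rather than `int64_t`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Per-dimension alignment request, mirroring the fields read in the removed
// code (align_factor == 0 means "no alignment for this dimension").
struct DimAlign {
  int64_t align_factor;
  int64_t align_offset;
};

// Row-major strides where each aligned dimension's stride satisfies
// stride % align_factor == align_offset (assuming 0 <= offset < factor).
std::vector<int64_t> AlignedStrides(const std::vector<int64_t>& shape,
                                    const std::vector<DimAlign>& aligns) {
  assert(aligns.size() == shape.size());
  std::vector<int64_t> strides(shape.size());
  int64_t stride = 1;
  // Iterate from the innermost dimension outward, as the original loop does.
  for (std::size_t i = shape.size(); i != 0; --i) {
    std::size_t dim = i - 1;
    int64_t factor = aligns[dim].align_factor;
    if (factor != 0) {
      int64_t offset = aligns[dim].align_offset;
      // Smallest non-negative padding that reaches the requested residue.
      stride += (factor + offset - stride % factor) % factor;
    }
    strides[dim] = stride;
    stride *= shape[dim];
  }
  return strides;
}
```

For example, a `{4, 30}` buffer whose outer dimension requests `align_factor = 32, align_offset = 0` gets strides `{32, 1}`: the row stride is padded from 30 to 32.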



-- 
This is an automated message from the Apache Git Service.