Lunderberg commented on a change in pull request #9091:
URL: https://github.com/apache/tvm/pull/9091#discussion_r715837497
##########
File path: src/tir/transforms/storage_flatten.cc
##########
@@ -50,6 +50,913 @@ using runtime::StorageRank;
using runtime::StorageScope;
using runtime::ThreadScope;
+/* Make buffer realize extents and buffer shapes consistent
+ *
+ * For external buffers, verify that the extents of BufferRealize
+ * nodes match the shape of the external buffer. For internal
+ * buffers, rewrite the shape of the Buffer objects to match the
+ * extent of the BufferRealize, and rewrite indices of
+ * BufferLoad/BufferStore nodes to match.
+ */
+class BufferShapeLegalize : public StmtExprMutator {
+ public:
+ // Construct the legalizer.
+ //
+ // extern_buffer_map: the Buffer values of this map are recorded in
+ //   extern_buffers_; the Var keys are not used here.
+ // bound_analyzer: stored as-is in bound_analyzer_ (not owned or
+ //   copied by this constructor).
+ explicit BufferShapeLegalize(const Map<Var, Buffer>& extern_buffer_map,
+                              IRVisitorWithAnalyzer* bound_analyzer)
+     : bound_analyzer_(bound_analyzer) {
+   for (auto var_and_buffer : extern_buffer_map) {
+     const Buffer& external = var_and_buffer.second;
+     extern_buffers_.insert(external);
+   }
+ }
+
+ // Legalize a BufferRealize node.
+ //
+ // External buffers (those recorded in extern_buffers_) are left
+ // untouched; instead, each dimension's realized extent is checked
+ // against the buffer's declared shape. A comparison that simplifies
+ // to a constant is checked immediately with ICHECK; otherwise the
+ // realize statement is wrapped in an AssertStmt carrying the same
+ // message.
+ //
+ // For any other buffer, a copy of the Buffer is made whose shape is
+ // the realized extents, the realization bounds are rewritten to
+ // start at zero, and the original per-dimension minimums are
+ // recorded in internal_buf_map_ so that BufferLoad/BufferStore
+ // visitors can shift their indices while the body is being visited.
+ Stmt VisitStmt_(const BufferRealizeNode* op) final {
+   // External buffers should not be changed.
+   if (extern_buffers_.count(op->buffer)) {
+     ICHECK_EQ(op->buffer->shape.size(), op->bounds.size())
+         << "External buffer realize has mismatched dimension";
+     Stmt stmt = StmtExprMutator::VisitStmt_(op);
+     op = stmt.as<BufferRealizeNode>();
+     ICHECK(op);
+
+     for (size_t i = 0; i < op->bounds.size(); i++) {
+       PrimExpr eq = bound_analyzer_->Simplify(op->buffer->shape[i] == op->bounds[i]->extent);
+       std::ostringstream ss;
+       ss << "Dim " << i << " of external buffer " << op->buffer->name << " has shape "
+          << op->buffer->shape[i] << ", but is only realized for extent "
+          << op->bounds[i]->extent;
+       if (auto eq_int = eq.as<IntImmNode>()) {
+         // The comparison folded to a constant: check it right now.
+         ICHECK(eq_int->value) << ss.str();
+       } else {
+         // The comparison is still symbolic: emit the check into the
+         // IR, wrapping whatever has been built so far.
+         stmt = AssertStmt(eq, tvm::tir::StringImm(ss.str()), stmt);
+       }
+     }
+     return stmt;
+   }
+
+   // Compute the new buffer shape, new realization bounds, and the
+   // offsets to be applied to buffer access.
+   Array<PrimExpr> realized_shape;
+   Array<PrimExpr> realized_begins;
+   Array<Range> new_bounds;
+   for (size_t i = 0; i < op->bounds.size(); i++) {
+     const Range& bound = op->bounds[i];
+     realized_shape.push_back(bound->extent);
+     realized_begins.push_back(bound->min);
+     new_bounds.push_back({0, bound->extent});
+   }
+
+   // The original Buffer is the lookup key; `buf` becomes the
+   // replacement whose shape equals the realized extents.
+   Buffer key = op->buffer;
+
+   Buffer buf = op->buffer;
+   auto write_ptr = buf.CopyOnWrite();
+   write_ptr->shape = realized_shape;
+
+   {
+     InternalBufferRemap remap;
+     remap.remap_to = buf;
+     remap.realized_begins = realized_begins;
+     remap.in_scope = true;
+     internal_buf_map_[key] = remap;
+   }
+
+   // The body is visited after the remap entry is registered, so
+   // accesses inside it see the entry with in_scope == true.
+   Stmt stmt = BufferRealize(buf, new_bounds, op->condition, this->VisitStmt(op->body), op->span);
+
+   // Keep the entry (other visitors still look it up), but mark it
+   // out of scope so later loads/stores of this buffer are rejected.
+   internal_buf_map_.at(key).in_scope = false;
+
+   return stmt;
+ }
+
+ // Rewrite a BufferStore that targets a remapped internal buffer.
+ //
+ // Children are mutated first. If the stored-to buffer has an entry
+ // in internal_buf_map_, each index is shifted down by the matching
+ // realized_begins value and the store is redirected to the entry's
+ // remap_to buffer. Stores to any other buffer are returned as
+ // produced by the base-class visitor.
+ Stmt VisitStmt_(const BufferStoreNode* op) final {
+   Stmt visited = StmtExprMutator::VisitStmt_(op);
+   const BufferStoreNode* store = visited.as<BufferStoreNode>();
+   ICHECK(store);
+
+   auto found = internal_buf_map_.find(store->buffer);
+   if (found == internal_buf_map_.end()) {
+     // Not a remapped buffer: nothing more to do.
+     return visited;
+   }
+
+   const InternalBufferRemap& remap = found->second;
+   ICHECK(remap.in_scope) << "Cannot store to an out-of-scope buffer";
+   ICHECK_EQ(remap.realized_begins.size(), store->indices.size())
+       << "Inconsistent dimensions for buffer " << store->buffer->name;
+
+   // Shift every index so that it is relative to the realized minimum.
+   Array<PrimExpr> shifted_indices;
+   for (size_t dim = 0; dim < remap.realized_begins.size(); dim++) {
+     shifted_indices.push_back(store->indices[dim] - remap.realized_begins[dim]);
+   }
+
+   BufferStore rewritten = GetRef<BufferStore>(store);
+   auto node = rewritten.CopyOnWrite();
+   node->indices = shifted_indices;
+   node->buffer = remap.remap_to;
+   return rewritten;
+ }
+
+ // Rewrite a BufferLoad that reads from a remapped internal buffer.
+ //
+ // Children are mutated first. If the loaded buffer has an entry in
+ // internal_buf_map_, each index is shifted down by the matching
+ // realized_begins value and the load is redirected to the entry's
+ // remap_to buffer. Loads from any other buffer are returned as
+ // produced by the base-class visitor.
+ PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+   PrimExpr visited = StmtExprMutator::VisitExpr_(op);
+   const BufferLoadNode* load = visited.as<BufferLoadNode>();
+   ICHECK(load);
+
+   auto found = internal_buf_map_.find(load->buffer);
+   if (found == internal_buf_map_.end()) {
+     // Not a remapped buffer: nothing more to do.
+     return visited;
+   }
+
+   const InternalBufferRemap& remap = found->second;
+   ICHECK(remap.in_scope) << "Cannot read from an out-of-scope buffer";
+   ICHECK_EQ(remap.realized_begins.size(), load->indices.size())
+       << "Inconsistent dimensions for buffer " << load->buffer->name;
+
+   // Shift every index so that it is relative to the realized minimum.
+   Array<PrimExpr> shifted_indices;
+   for (size_t dim = 0; dim < remap.realized_begins.size(); dim++) {
+     shifted_indices.push_back(load->indices[dim] - remap.realized_begins[dim]);
+   }
+
+   BufferLoad rewritten = GetRef<BufferLoad>(load);
+   auto node = rewritten.CopyOnWrite();
+   node->indices = shifted_indices;
+   node->buffer = remap.remap_to;
+   return rewritten;
+ }
+
+ // Handle attribute statements.
+ //
+ // - If the annotated node is a Buffer, the body is visited and the
+ //   AttrStmt is rebuilt, pointing at the remapped buffer when
+ //   internal_buf_map_ has an entry for the original one.
+ // - Otherwise, buffer_bind_scope attributes are forwarded to
+ //   HandleBufferBindScope.
+ // - Everything else goes through the base-class visitor.
+ Stmt VisitStmt_(const AttrStmtNode* op) final {
+   if (!op->node->IsInstance<tir::BufferNode>()) {
+     if (op->attr_key == attr::buffer_bind_scope) {
+       return HandleBufferBindScope(op);
+     }
+     return StmtExprMutator::VisitStmt_(op);
+   }
+
+   // Visit body before checking internal_buf_map_, because we
+   // don't know if the BufferNode needs to be changed until we
+   // look in the body for a BufferRealizeNode with different
+   // extents.
+   Stmt new_body = this->VisitStmt(op->body);
+
+   Buffer annotated = Downcast<tir::Buffer>(op->node);
+   auto found = internal_buf_map_.find(annotated);
+   if (found != internal_buf_map_.end()) {
+     annotated = found->second.remap_to;
+   }
+   return AttrStmt(annotated, op->attr_key, op->value, new_body);
+ }
+
+ private:
+ Stmt HandleBufferBindScope(const AttrStmtNode* op) {
Review comment:
Makes sense, and comments have been added.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]