Lunderberg commented on a change in pull request #9727:
URL: https://github.com/apache/tvm/pull/9727#discussion_r807375704
##########
File path: src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
##########
@@ -294,22 +308,61 @@ class DynamicSharedMemoryRewriter : public
StmtExprMutator {
}
PrimExpr VisitExpr_(const LoadNode* op) final {
- if (IsDynamicSharedMemory(op->buffer_var)) {
- PrimExpr offset = GetBufferOffset(op->buffer_var, op->dtype);
- PrimExpr index = StmtExprMutator::VisitExpr(op->index);
- return Load(op->dtype, merged_buf_var_, offset + index, op->predicate,
op->span);
- }
- return StmtExprMutator::VisitExpr_(op);
+ LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use
BufferLoadNode instead.";
+ return PrimExpr();
}
Stmt VisitStmt_(const StoreNode* op) final {
- if (IsDynamicSharedMemory(op->buffer_var)) {
- PrimExpr offset = GetBufferOffset(op->buffer_var, op->value->dtype);
- PrimExpr index = StmtExprMutator::VisitExpr(op->index);
- PrimExpr value = StmtExprMutator::VisitExpr(op->value);
- return Store(merged_buf_var_, value, offset + index, op->predicate,
op->span);
+ LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use
BufferStoreNode instead.";
+ return Stmt();
+ }
+
+ PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+ auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
+ return VisitBufferAccess(std::move(node));
+ }
+
+ Stmt VisitStmt_(const BufferStoreNode* op) final {
+ auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
+ return VisitBufferAccess(std::move(node));
+ }
+
+ template <typename Node>
+ Node VisitBufferAccess(Node node) {
+ if (IsDynamicSharedMemory(node->buffer->data)) {
+ ICHECK_EQ(node->indices.size(), 1)
+ << "MergeDynamicSharedMemoryAllocations expects flat memory buffers,
"
+ << "and is to be run after "
+ << "StorageFlatten (TE schedules) or FlattenBuffer (TIR schedules)";
+ Array<PrimExpr> indices = {node->indices[0] +
+ this->GetBufferOffset(node->buffer->data,
node->buffer->dtype)};
+
+ auto writer = node.CopyOnWrite();
+ writer->buffer = GetUpdatedBuffer(node->buffer);
+ writer->indices = indices;
}
- return StmtExprMutator::VisitStmt_(op);
+
+ return node;
+ }
+
+ Buffer GetUpdatedBuffer(Buffer buffer) {
+ auto key = buffer.get();
+ auto it = buffer_remap_.find(key);
+ if (it != buffer_remap_.end()) {
+ return it->second;
+ }
+
+ if (IsDynamicSharedMemory(buffer->data)) {
+ ICHECK_EQ(buffer->shape.size(), 1)
+ << "MergeDynamicSharedMemoryAllocations expects flat memory buffers,
"
+ << "and is to be run after "
+ << "StorageFlatten (TE schedules) or FlattenBuffer (TIR schedules)";
+ auto writer = buffer.CopyOnWrite();
+ writer->data = merged_buf_var_;
Review comment:
I think the aliasing is necessary overall, because the dynamic shared
memory buffers aren't necessarily of the same underlying type, so each one
needs its own typed view of the single merged allocation. That said, I can
imagine a later improvement that would identify buffers of the same type and
then merge those, but I didn't want to add that complexity the first time around.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]