Lunderberg commented on a change in pull request #9727:
URL: https://github.com/apache/tvm/pull/9727#discussion_r807375704



##########
File path: src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
##########
@@ -294,22 +308,61 @@ class DynamicSharedMemoryRewriter : public 
StmtExprMutator {
   }
 
   PrimExpr VisitExpr_(const LoadNode* op) final {
-    if (IsDynamicSharedMemory(op->buffer_var)) {
-      PrimExpr offset = GetBufferOffset(op->buffer_var, op->dtype);
-      PrimExpr index = StmtExprMutator::VisitExpr(op->index);
-      return Load(op->dtype, merged_buf_var_, offset + index, op->predicate, 
op->span);
-    }
-    return StmtExprMutator::VisitExpr_(op);
+    LOG(FATAL) << "Unexpected use of deprecated LoadNode.  Please use 
BufferLoadNode instead.";
+    return PrimExpr();
   }
 
   Stmt VisitStmt_(const StoreNode* op) final {
-    if (IsDynamicSharedMemory(op->buffer_var)) {
-      PrimExpr offset = GetBufferOffset(op->buffer_var, op->value->dtype);
-      PrimExpr index = StmtExprMutator::VisitExpr(op->index);
-      PrimExpr value = StmtExprMutator::VisitExpr(op->value);
-      return Store(merged_buf_var_, value, offset + index, op->predicate, 
op->span);
+    LOG(FATAL) << "Unexpected use of deprecated StoreNode.  Please use 
BufferStoreNode instead.";
+    return Stmt();
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
+    return VisitBufferAccess(std::move(node));
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
+    return VisitBufferAccess(std::move(node));
+  }
+
+  template <typename Node>
+  Node VisitBufferAccess(Node node) {
+    if (IsDynamicSharedMemory(node->buffer->data)) {
+      ICHECK_EQ(node->indices.size(), 1)
+          << "MergeDynamicSharedMemoryAllocations expects flat memory buffers, 
"
+          << "and is to be run after "
+          << "StorageFlatten (TE schedules) or FlattenBuffer (TIR schedules)";
+      Array<PrimExpr> indices = {node->indices[0] +
+                                 this->GetBufferOffset(node->buffer->data, 
node->buffer->dtype)};
+
+      auto writer = node.CopyOnWrite();
+      writer->buffer = GetUpdatedBuffer(node->buffer);
+      writer->indices = indices;
     }
-    return StmtExprMutator::VisitStmt_(op);
+
+    return node;
+  }
+
+  Buffer GetUpdatedBuffer(Buffer buffer) {
+    auto key = buffer.get();
+    auto it = buffer_remap_.find(key);
+    if (it != buffer_remap_.end()) {
+      return it->second;
+    }
+
+    if (IsDynamicSharedMemory(buffer->data)) {
+      ICHECK_EQ(buffer->shape.size(), 1)
+          << "MergeDynamicSharedMemoryAllocations expects flat memory buffers, 
"
+          << "and is to be run after "
+          << "StorageFlatten (TE schedules) or FlattenBuffer (TIR schedules)";
+      auto writer = buffer.CopyOnWrite();
+      writer->data = merged_buf_var_;

Review comment:
       I think the aliasing is necessary overall, because the dynamic shared 
memory buffers don't necessarily have the same underlying type.  That said, I can 
imagine a later improvement that would identify buffers of the same type and 
then merge those, but I didn't want to add complexity for the first time around.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to