gemini-code-assist[bot] commented on code in PR #19610:
URL: https://github.com/apache/tvm/pull/19610#discussion_r3303430795


##########
src/s_tir/transform/merge_shared_memory_allocations.cc:
##########
@@ -274,100 +276,140 @@ class SharedMemLinearAccessPatternFinder final : public 
StmtExprVisitor {
 
 /*!
  * \brief merge the buffers whose live range has no intersection and rewrite 
the body
+ *
+ * Uses a scope-stack design: each thread_extent block (kernel launch) gets its
+ * own KernelScope that owns the merged buffer var and all per-launch 
bookkeeping.
+ * This correctly handles PrimFuncs with multiple sibling thread_extent blocks.
  */
 class SharedMemoryRewriter : public StmtExprMutator {
  public:
-  explicit SharedMemoryRewriter(const std::unordered_map<const VarNode*, 
Buffer>& shmem_allocs,
-                                bool is_dynamic = true)
-      : is_dynamic_{is_dynamic}, shmem_allocs_{shmem_allocs} {
-    if (!is_dynamic) {
-      merged_buf_var_ = Var("buf_shmem", 
PointerType(PrimType(DataType::UInt(8)), "shared"));
-    }
-  }
+  explicit SharedMemoryRewriter(bool is_dynamic = true) : 
is_dynamic_{is_dynamic} {}
+
+ private:
+  using StmtEntry = SharedMemLinearAccessPatternFinder::StmtEntry;
+
+  struct StorageEntry {
+    // The constant size of the buffer in bits, only used if it is constant
+    uint64_t const_nbits{0};
+    // Allocs that share this entry.
+    // The inner vector means a "layer"
+    // For example, if we need to allocate C in the memory of A and B:
+    // |  A: 4096 bytes |  B: 4096 bytes |
+    // |            C: 8192 bytes        |
+    // Then the allocs = {{A, B}, {C}}
+    std::vector<std::vector<const VarNode*>> allocs;
+  };
+
+  // Event entry in liveness analysis
+  struct EventEntry {
+    // variables we generate
+    std::vector<const VarNode*> gen;
+    // variables we kill
+    std::vector<const VarNode*> kill;
+  };
+
+  /*!
+   * \brief Per-kernel-launch scope holding all state for one thread_extent 
block.
+   */
+  struct KernelScope {
+    // The merged buffer var for THIS kernel launch.
+    Var merged_buf_var;
+    // Total byte size of THIS kernel's merged buffer.
+    PrimExpr merged_alloc_size{0};
+    // Allocations from THIS kernel's subtree.
+    std::unordered_map<const VarNode*, Buffer> shmem_allocs;
+    // Per-buffer byte offset into merged_buf_var.
+    std::unordered_map<const VarNode*, PrimExpr> buffer_byte_offsets;
+    // Buffer-object remap: original Buffer -> merged-data-var Buffer.
+    std::unordered_map<const BufferNode*, Buffer> buffer_remap;
+    // Has any original alloc in this scope been marked volatile?
+    bool has_volatile_alloc{false};
+    // Liveness data (event_map, alloc_map, const_free_map, sym_free_list) — 
all per-scope.
+    std::unordered_map<const ffi::Object*, EventEntry> event_map;
+    std::multimap<uint64_t, StorageEntry*> const_free_map;
+    std::list<StorageEntry*> sym_free_list;
+    std::unordered_map<const VarNode*, StorageEntry*> alloc_map;
+  };
 
   /*!
-   * \brief plan the memory reuse for all the buffer allocated in the statement
-   * \param stmt the statement
+   * \brief Create a fresh merged buffer Var for a new kernel scope.
+   *        Same name string is fine — Var identity is by pointer, not name.
    */
-  void PlanReuse(const Stmt& stmt, bool is_dynamic = true) {
-    SharedMemLinearAccessPatternFinder finder(is_dynamic);
-    finder(stmt);
-    this->LivenessAnalysis(finder.linear_seq_);
-    this->PlanMemory(finder.linear_seq_);
+  Var MakeMergedBufferVar() {
+    if (is_dynamic_) {
+      return Var("buf_dyn_shmem", PointerType(PrimType(DataType::UInt(8)), 
"shared.dyn"));
+    } else {
+      return Var("buf_shmem", PointerType(PrimType(DataType::UInt(8)), 
"shared"));
+    }
   }
 
- private:
   Stmt VisitStmt_(const AttrStmtNode* op) final {
-    if (op->attr_key == tirx::attr::thread_extent && !allocated_) {
-      // Allocate one dynamic shared memory allocation at the beginning of 
thread scope
-      int max_layer_num = 0;
-      std::vector<const StorageEntry*> all_entry;
-      for (const auto& e : const_free_map_) {
-        all_entry.push_back(e.second);
-      }
-      for (const StorageEntry* e : sym_free_list_) {
-        all_entry.push_back(e);
-      }
-      for (const StorageEntry* e : all_entry) {
-        max_layer_num = std::max(max_layer_num, 
static_cast<int>(e->allocs.size()));
-      }
-      // calculate align for each layer of each storage entry.
-      std::vector<int> align(max_layer_num, 0);
-      for (const StorageEntry* e : all_entry) {
-        for (int i = 0; i < static_cast<int>(e->allocs.size()); i++) {
-          for (const VarNode* buffer : e->allocs[i]) {
-            const Buffer& buf = shmem_allocs_.at(buffer);
-            align[i] = std::max(align[i], buf->dtype.bytes());
-          }
-        }
-      }
-      // calculate offset for each buffer based on the align of each layer
-      for (const StorageEntry* e : all_entry) {
-        PrimExpr max_inner_offset = 0;
-        for (int i = 0; i < static_cast<int>(e->allocs.size()); i++) {
-          PrimExpr inner_offset = 0;
-          for (const VarNode* buffer : e->allocs[i]) {
-            const Buffer& buf = shmem_allocs_.at(buffer);
-            ffi::Array<PrimExpr> alloc_shape = GetBufferAllocationShape(buf);
-            int align_bytes = std::max(align[i], buf->dtype.bytes());
-            if (buf->data_alignment > 0) {
-              TVM_FFI_ICHECK(buf->data_alignment % align_bytes == 0)
-                  << "The alignment of the buffer is not a multiple of the 
data type size.";
-              align_bytes = buf->data_alignment;
-            }
-            PrimExpr buffer_bytes = alloc_shape[0] * buf->dtype.bytes();
-            inner_offset +=
-                indexmod(align_bytes - indexmod(merged_alloc_size_ + 
inner_offset, align_bytes),
-                         align_bytes);
-            buffer_byte_offsets_[buffer] = merged_alloc_size_ + inner_offset;
-            inner_offset += buffer_bytes;
-          }
-          max_inner_offset = max(max_inner_offset, inner_offset);
-        }
-        merged_alloc_size_ += max_inner_offset;
-      }
+    if (op->attr_key == tirx::attr::thread_extent && !in_thread_env_) {
+      in_thread_env_ = true;
+
+      // 1. Push a fresh scope.
+      scope_stack_.emplace_back();
+      KernelScope& scope = scope_stack_.back();
+      scope.merged_buf_var = MakeMergedBufferVar();
+
+      // 2. Collect shmem allocs that belong to THIS subtree.
+      AllocateCollector collector(is_dynamic_);
+      collector(op->body);
+      scope.shmem_allocs = std::move(collector.shmem_allocs_);
+

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   If a `thread_extent` block contains at most one shared memory allocation, 
there is no need to perform liveness analysis, memory planning, or buffer 
rewriting for that block. We can optimize this by bailing out early when 
`scope.shmem_allocs.size() <= 1`, which avoids unnecessary AST mutations and 
keeps the generated IR cleaner.
   
   ```suggestion
         // 2. Collect shmem allocs that belong to THIS subtree.
         AllocateCollector collector(is_dynamic_);
         collector(op->body);
         scope.shmem_allocs = std::move(collector.shmem_allocs_);
   
         if (scope.shmem_allocs.size() <= 1) {
           scope_stack_.pop_back();
           in_thread_env_ = false;
           return StmtMutator::VisitStmt_(op);
         }
   ```



##########
src/s_tir/transform/merge_shared_memory_allocations.cc:
##########
@@ -685,51 +777,28 @@ class SharedMemoryRewriter : public StmtExprMutator {
 
     // normal free.
     if (e->const_nbits != 0) {
-      const_free_map_.insert({e->const_nbits, e});
+      scope.const_free_map.insert({e->const_nbits, e});
     } else {
-      sym_free_list_.push_back(e);
+      scope.sym_free_list.push_back(e);
     }
   }
+
   // Whether enable dynamic analysis.
   bool is_dynamic_{true};
-  // The var for the merged buffer
-  Var merged_buf_var_{"buf_dyn_shmem", 
PointerType(PrimType(DataType::UInt(8)), "shared.dyn")};
-  // The mapping from the original buffer var to its Buffer
-  std::unordered_map<const VarNode*, Buffer> shmem_allocs_;
-  // The size of the merged buffer
-  PrimExpr merged_alloc_size_{0};
-  // The mapping from the original buffer var to its offset in the merged 
buffer
-  std::unordered_map<const VarNode*, PrimExpr> buffer_byte_offsets_;
-  // The mapping from the original buffer objects to their location in the 
merged buffer.
-  std::unordered_map<const BufferNode*, Buffer> buffer_remap_;
-  // The flag indicating whether the merged buffer has been allocated
-  bool allocated_{false};
-  // Whether any original shared memory allocation had the volatile annotation
-  bool has_volatile_alloc_{false};
-  // Locations of free ops.
-  std::unordered_map<const ffi::Object*, EventEntry> event_map_;
-  // constant size free map.
-  std::multimap<uint64_t, StorageEntry*> const_free_map_;
-  // symbolic free list, for non constant items.
-  std::list<StorageEntry*> sym_free_list_;
-  // The allocation assign map
-  std::unordered_map<const VarNode*, StorageEntry*> alloc_map_;
-  /*! \brief allocator of all the StorageEntry*/
+  // Whether already inside a thread_extent (outermost only).
+  bool in_thread_env_{false};
+  // Stack of per-kernel-launch scopes. Pushed on thread_extent entry, popped 
on exit.
+  std::vector<KernelScope> scope_stack_;
+  /*! \brief allocator of all the StorageEntry (shared across all scopes) */
   support::Arena arena_;
 };
 
 Stmt MergeSharedMemoryAllocations(Stmt stmt, bool merge_static_smem) {
-  AllocateCollector collector;
-  collector(stmt);
-  if (collector.dyn_shmem_allocs_.size() > 1) {
-    SharedMemoryRewriter rewriter(collector.dyn_shmem_allocs_);
-    rewriter.PlanReuse(stmt);
-    stmt = rewriter(std::move(stmt));
-  }
-  if (merge_static_smem && collector.static_shmem_allocs_.size() > 1) {
-    SharedMemoryRewriter rewriter(collector.static_shmem_allocs_, false);
-    rewriter.PlanReuse(stmt, false);
-    stmt = rewriter(std::move(stmt));
+  SharedMemoryRewriter dyn_rewriter(/*is_dynamic=*/true);
+  stmt = dyn_rewriter(std::move(stmt));
+  if (merge_static_smem) {
+    SharedMemoryRewriter static_rewriter(/*is_dynamic=*/false);
+    stmt = static_rewriter(std::move(stmt));
   }
   return stmt;
 }

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   To avoid traversing the entire AST when there are no shared memory 
allocations (or at most one) in the entire function, we should run 
`AllocateCollector` first at the function level and bail out early. This 
restores the fast-path optimization from the original implementation and 
significantly improves compilation performance for functions that do not use 
shared memory.
   
   ```cpp
   Stmt MergeSharedMemoryAllocations(Stmt stmt, bool merge_static_smem) {
     {
       AllocateCollector collector(/*is_dynamic=*/true);
       collector(stmt);
       if (collector.shmem_allocs_.size() > 1) {
         SharedMemoryRewriter dyn_rewriter(/*is_dynamic=*/true);
         stmt = dyn_rewriter(std::move(stmt));
       }
     }
     if (merge_static_smem) {
       AllocateCollector collector(/*is_dynamic=*/false);
       collector(stmt);
       if (collector.shmem_allocs_.size() > 1) {
         SharedMemoryRewriter static_rewriter(/*is_dynamic=*/false);
         stmt = static_rewriter(std::move(stmt));
       }
     }
     return stmt;
   }
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to