gemini-code-assist[bot] commented on code in PR #19605:
URL: https://github.com/apache/tvm/pull/19605#discussion_r3299801604


##########
python/tvm/s_tir/pipeline.py:
##########
@@ -108,14 +108,13 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
             passes.append(s_tir.transform.InjectPTXLDG32())
         passes.extend(
             [
+                s_tir.transform.MergeSharedMemoryAllocations(),

Review Comment:
   ![critical](https://www.gstatic.com/codereviewagent/critical.svg)
   
   ### Scoping and Correctness Bug with Multi-Kernel PrimFuncs

   Moving `MergeSharedMemoryAllocations` before `AnnotateDeviceRegions` (and thus
   before `SplitHostDevice`) introduces a critical correctness/scoping bug when a
   single `PrimFunc` contains multiple device regions (i.e., multiple `AttrStmt`
   blocks with `attr::thread_extent`).

   #### Why this happens:

   1. **Global Merging**: `MergeSharedMemoryAllocations` is a `PrimFunc` pass. It
      collects *all* shared memory allocations across the entire `PrimFunc` and
      plans to merge them into a single `merged_buf_var_` (e.g., `buf_dyn_shmem`).
   2. **Single Allocation Site**: In
      `SharedMemoryRewriter::VisitStmt_(const AttrStmtNode* op)`, the pass only
      allocates the merged buffer at the *first* `thread_extent` block it
      encounters, setting `allocated_ = true`:
      ```cpp
      if (op->attr_key == tirx::attr::thread_extent && !allocated_) {
        ...
        allocated_ = true;
        ...
        return AttrStmt(op->node, op->attr_key, op->value, new_body, op->span);
      }
      ```
   3. **Scoping Violation**: Any subsequent `thread_extent` blocks (representing
      other device kernels in the same `PrimFunc`) have their original
      `AllocBuffer` statements removed and their accesses rewritten to use
      `merged_buf_var_`. However, because `merged_buf_var_` is allocated only
      inside the body of the *first* `thread_extent` block, it is not in scope
      for the subsequent blocks.
   4. **Undefined Variables after Split**: When `SplitHostDevice` later splits
      these `thread_extent` blocks into separate device functions, the second
      device function references `merged_buf_var_` (e.g., `buf_dyn_shmem`),
      which is undefined in its scope, leading to compilation failures or
      runtime crashes.

   #### Suggested Solution:

   To keep `AnnotateDeviceRegions`, `SplitHostDevice`, and
   `LowerDeviceKernelLaunch` consecutive, `MergeSharedMemoryAllocations` must
   either be refactored to support multiple `thread_extent` blocks (by tracking
   and allocating a merged buffer per `thread_extent` scope), or remain after
   `SplitHostDevice` but before `LowerDeviceKernelLaunch` (which would prevent
   running the three passes consecutively).
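   A toy Python model of the failure mode may help when reviewing the pass order. It is a hypothetical sketch, not TVM code: `Region`, `merge_single_site`, and `undefined_after_split` are illustrative names, each `Region` stands in for one `thread_extent` block, and `SplitHostDevice` is approximated by treating each region as its own scope. The merge step mimics the `allocated_` flag by rewriting every shared allocation in the function to one merged buffer while emitting that buffer's allocation only in the first region.

   ```python
   from __future__ import annotations

   from dataclasses import dataclass


   # Hypothetical toy IR, not TVM data structures: one Region per thread_extent block.
   @dataclass
   class Region:
       name: str
       allocs: list[str]  # shared-memory buffers allocated inside this region
       uses: list[str]    # shared-memory buffers accessed inside this region


   def merge_single_site(regions: list[Region], merged: str = "buf_dyn_shmem") -> list[Region]:
       """Merge every shared allocation in the function but, like the
       `allocated_` flag, emit the merged allocation only in the first region."""
       all_shared = {b for r in regions for b in r.allocs}
       allocated = False
       out = []
       for r in regions:
           # Every access to any merged buffer is rewritten to the merged var.
           uses = [merged if u in all_shared else u for u in r.uses]
           allocs = []  # the original per-buffer allocations are removed
           if not allocated:
               allocs.append(merged)
               allocated = True
           out.append(Region(r.name, allocs, uses))
       return out


   def undefined_after_split(regions: list[Region]) -> dict[str, list[str]]:
       """Treat each region as its own device function (as after SplitHostDevice)
       and report the buffers it uses without allocating them."""
       return {r.name: sorted(set(r.uses) - set(r.allocs)) for r in regions}


   kernels = [
       Region("kernel_0", ["A_shared"], ["A_shared"]),
       Region("kernel_1", ["B_shared"], ["B_shared"]),
   ]
   print(undefined_after_split(kernels))                     # {'kernel_0': [], 'kernel_1': []}
   print(undefined_after_split(merge_single_site(kernels)))  # {'kernel_0': [], 'kernel_1': ['buf_dyn_shmem']}
   ```

   In this model the second kernel ends up reading `buf_dyn_shmem` without allocating it, which corresponds to the undefined-variable situation in point 4.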



##########
python/tvm/s_tir/backend/adreno/pipeline.py:
##########
@@ -108,14 +108,13 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
             passes.append(s_tir.transform.InjectPTXLDG32())
         passes.extend(
             [
+                s_tir.transform.MergeSharedMemoryAllocations(),

Review Comment:
   ![critical](https://www.gstatic.com/codereviewagent/critical.svg)
   
   ### Scoping and Correctness Bug with Multi-Kernel PrimFuncs

   Moving `MergeSharedMemoryAllocations` before `AnnotateDeviceRegions` (and thus
   before `SplitHostDevice`) introduces a critical correctness/scoping bug when a
   single `PrimFunc` contains multiple device regions (i.e., multiple `AttrStmt`
   blocks with `attr::thread_extent`).

   #### Why this happens:

   1. **Global Merging**: `MergeSharedMemoryAllocations` is a `PrimFunc` pass. It
      collects *all* shared memory allocations across the entire `PrimFunc` and
      plans to merge them into a single `merged_buf_var_` (e.g., `buf_dyn_shmem`).
   2. **Single Allocation Site**: In
      `SharedMemoryRewriter::VisitStmt_(const AttrStmtNode* op)`, the pass only
      allocates the merged buffer at the *first* `thread_extent` block it
      encounters, setting `allocated_ = true`:
      ```cpp
      if (op->attr_key == tirx::attr::thread_extent && !allocated_) {
        ...
        allocated_ = true;
        ...
        return AttrStmt(op->node, op->attr_key, op->value, new_body, op->span);
      }
      ```
   3. **Scoping Violation**: Any subsequent `thread_extent` blocks (representing
      other device kernels in the same `PrimFunc`) have their original
      `AllocBuffer` statements removed and their accesses rewritten to use
      `merged_buf_var_`. However, because `merged_buf_var_` is allocated only
      inside the body of the *first* `thread_extent` block, it is not in scope
      for the subsequent blocks.
   4. **Undefined Variables after Split**: When `SplitHostDevice` later splits
      these `thread_extent` blocks into separate device functions, the second
      device function references `merged_buf_var_` (e.g., `buf_dyn_shmem`),
      which is undefined in its scope, leading to compilation failures or
      runtime crashes.

   #### Suggested Solution:

   To keep `AnnotateDeviceRegions`, `SplitHostDevice`, and
   `LowerDeviceKernelLaunch` consecutive, `MergeSharedMemoryAllocations` must
   either be refactored to support multiple `thread_extent` blocks (by tracking
   and allocating a merged buffer per `thread_extent` scope), or remain after
   `SplitHostDevice` but before `LowerDeviceKernelLaunch` (which would prevent
   running the three passes consecutively).
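   The per-scope refactor suggested above can be illustrated with a small toy model. This is a hypothetical sketch, not TVM's `SharedMemoryRewriter`: `Region`, `merge_per_scope`, and `undefined_after_split` are made-up names, and each `Region` stands in for one `thread_extent` block. Shared allocations are merged only within the region that declares them, so every region that uses a merged buffer also allocates it.

   ```python
   from __future__ import annotations

   from dataclasses import dataclass


   # Hypothetical toy IR, not TVM data structures: one Region per thread_extent block.
   @dataclass
   class Region:
       name: str
       allocs: list[str]  # shared-memory buffers allocated inside this region
       uses: list[str]    # shared-memory buffers accessed inside this region


   def merge_per_scope(regions: list[Region], merged: str = "buf_dyn_shmem") -> list[Region]:
       """Merge shared allocations separately inside each thread_extent region."""
       out = []
       for r in regions:
           local = set(r.allocs)
           # Only accesses to this region's own buffers are rewritten.
           uses = [merged if u in local else u for u in r.uses]
           # Each region that had shared allocations gets its own merged allocation.
           allocs = [merged] if local else []
           out.append(Region(r.name, allocs, uses))
       return out


   def undefined_after_split(regions: list[Region]) -> dict[str, list[str]]:
       """Treat each region as its own device function (as after SplitHostDevice)
       and report the buffers it uses without allocating them."""
       return {r.name: sorted(set(r.uses) - set(r.allocs)) for r in regions}


   kernels = [
       Region("kernel_0", ["A_shared", "C_shared"], ["A_shared", "C_shared"]),
       Region("kernel_1", ["B_shared"], ["B_shared"]),
   ]
   print(undefined_after_split(merge_per_scope(kernels)))  # {'kernel_0': [], 'kernel_1': []}
   ```

   Reusing the name `buf_dyn_shmem` in every region is only harmless in this model because each region becomes a separate device function after the split; in real TIR each region would presumably need its own buffer `Var`, sized from that kernel's allocations alone.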



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

