gemini-code-assist[bot] commented on code in PR #18876:
URL: https://github.com/apache/tvm/pull/18876#discussion_r2892322985


##########
src/s_tir/transform/lower_thread_allreduce.cc:
##########
@@ -787,18 +808,106 @@ class ThreadAllreduceBuilder final : public 
StmtExprMutator {
   std::vector<const CommReducerNode*> reduce_combiner_;
   // The load remap
   std::unordered_map<const VarNode*, PrimExpr> load_remap_;
+  // Internal analyzer
+  arith::Analyzer analyzer_;
+
+ public:
+  // These members are public for post-processing by DeferredRemapper.
   // Allocate remap
   std::unordered_map<const VarNode*, Buffer> alloc_remap_;
   // BufferVar remap
   std::unordered_map<const VarNode*, Var> var_remap_;
   // Buffer remap
   std::unordered_map<const BufferNode*, Buffer> buf_remap_;
-  // Internal analyzer
-  arith::Analyzer analyzer_;
+  // Pending AllocBuffer original data pointers (for flat IR deferred 
remapping)
+  std::vector<const VarNode*> pending_alloc_buffers_;
 };
 
 namespace transform {
 
+/*!
+ * \brief Post-processing pass to apply deferred AllocBuffer remappings.
+ *
+ * In flat IR, AllocBuffer nodes may be visited before the alloc_remap_ is 
populated
+ * (since MakeAllreduce runs when Evaluate is visited, which is later in the 
flat sequence).
+ * This pass walks the result and applies any pending remappings.
+ */
+/*!
+ * \brief Post-processing pass to apply deferred remappings for flat IR.
+ *
+ * Handles AllocBuffer, DeclBuffer, and BufferLoad nodes whose remappings
+ * were not available during the main traversal.
+ */
+class DeferredRemapper : public StmtExprMutator {
+ public:
+  DeferredRemapper(const std::unordered_map<const VarNode*, Buffer>& 
alloc_remap,
+                   const std::unordered_map<const VarNode*, Var>& var_remap,
+                   const std::unordered_map<const BufferNode*, Buffer>& 
buf_remap,
+                   const std::vector<const VarNode*>& pending)
+      : alloc_remap_(alloc_remap), var_remap_(var_remap), 
buf_remap_(buf_remap) {
+    for (const VarNode* ptr : pending) {
+      pending_set_.insert(ptr);
+    }
+  }
+
+  bool HasPendingRemaps() const {
+    for (const VarNode* ptr : pending_set_) {
+      if (alloc_remap_.count(ptr)) return true;
+    }
+    return false;
+  }
+
+  Stmt VisitStmt_(const AllocBufferNode* op) final {
+    auto node = Downcast<AllocBuffer>(StmtExprMutator::VisitStmt_(op));
+    const VarNode* data_ptr = op->buffer->data.get();
+    if (pending_set_.count(data_ptr)) {
+      if (auto it = alloc_remap_.find(data_ptr); it != alloc_remap_.end()) {
+        const Buffer& replacement = it->second;
+        node.CopyOnWrite()->buffer = replacement;
+        if (replacement.scope() == "shared") {
+          Stmt volatile_attr =
+              AttrStmt(replacement->data, tir::attr::volatile_scope, 1, 
Evaluate(0));
+          return SeqStmt::Flatten(node, volatile_attr);

Review Comment:
   
![security-high](https://www.gstatic.com/codereviewagent/security-high-priority.svg)
 ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   Similar to the issue in `RemapAllocBuffer`, the `DeferredRemapper` also 
incorrectly applies the `volatile_scope` attribute to an empty body. As a 
result, the attribute is immediately out of scope for all subsequent 
operations on the buffer, breaking synchronization on GPU backends. This 
critical flaw can lead to race conditions or incorrect optimizations, as the 
`volatile` attribute for shared memory is crucial for correctness on GPU 
targets. The `AttrStmt` needs to wrap the code that actually uses the shared 
buffer.



##########
src/tir/transform/lower_tvm_builtin.cc:
##########
@@ -280,21 +271,20 @@ class BuiltinLower : public StmtExprMutator {
                              cast(DataType::Int(32), device_id_.value()), 
op->buffer->data});
     Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), 
throw_last_error);
 
-    Stmt body = op->body;
-    body = SeqStmt::Flatten(body, free_stmt);
-    body = SeqStmt::Flatten(alloc_nullptr_check, body);
-
-    body = AttrStmt(op->buffer->data, attr::storage_alignment,
-                    make_const(DataType::Int(32), 
runtime::kTempAllocaAlignment), body);
-    body = SeqStmt({Bind(op->buffer->data,
-                         Call(op->buffer->data.dtype(), 
Op::Get("tir.TVMBackendAllocWorkspace"),
-                              {cast(DataType::Int(32), device_type_.value()),
-                               cast(DataType::Int(32), device_id_.value()), 
total_bytes,
-                               IntImm(DataType::Int(32), 
op->buffer->dtype.code()),
-                               IntImm(DataType::Int(32), 
op->buffer->dtype.bits())})),
-                    body});
-
-    return body;
+    // Push free to enclosing scope's pending_frees (LIFO ordering preserved).
+    scope_.Current().pending_frees.push_back(free_stmt);
+
+    Stmt alloc_bind = Bind(op->buffer->data,
+                           Call(op->buffer->data.dtype(), 
Op::Get("tir.TVMBackendAllocWorkspace"),
+                                {cast(DataType::Int(32), device_type_.value()),
+                                 cast(DataType::Int(32), device_id_.value()), 
total_bytes,
+                                 IntImm(DataType::Int(32), 
op->buffer->dtype.code()),
+                                 IntImm(DataType::Int(32), 
op->buffer->dtype.bits())}));
+
+    Stmt result = AttrStmt(op->buffer->data, attr::storage_alignment,
+                           make_const(DataType::Int(32), 
runtime::kTempAllocaAlignment),
+                           SeqStmt({alloc_bind, alloc_nullptr_check}));
+    return result;

Review Comment:
   
![security-medium](https://www.gstatic.com/codereviewagent/security-medium-priority.svg)
 ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The `storage_alignment` attribute is now applied only to the allocation and 
nullptr check statements, rather than wrapping the entire body of the 
allocation as it did previously. Since `AttrStmt` is scoped to its body, the 
alignment information will be lost for all subsequent loads and stores using 
this buffer. While this may primarily impact performance by preventing the use 
of aligned instructions, it represents a loss of metadata that could affect 
correctness on architectures with strict alignment requirements.



##########
src/s_tir/transform/lower_thread_allreduce.cc:
##########
@@ -76,16 +76,32 @@ class ThreadAllreduceBuilder final : public StmtExprMutator 
{
     }
   }
   Stmt VisitStmt_(const AllocBufferNode* op) final {
+    // In flat IR, alloc_remap_ may not yet be populated when this AllocBuffer 
is visited
+    // (the remap is set up by MakeAllreduce which runs during 
AttrStmt/Evaluate visit
+    // that appears later in the sequence). We record the original data 
pointer and
+    // attempt the remap; if it's not ready, the post-processing pass will 
handle it.
+    const VarNode* orig_data_ptr = op->buffer->data.get();
     auto node = Downcast<AllocBuffer>(StmtExprMutator::VisitStmt_(op));
 
-    if (auto it = alloc_remap_.find(node->buffer->data.get()); it != 
alloc_remap_.end()) {
-      Buffer buf = Downcast<Buffer>(it->second);
-      auto write_ptr = node.CopyOnWrite();
-      write_ptr->buffer = buf;
+    if (auto it = alloc_remap_.find(orig_data_ptr); it != alloc_remap_.end()) {
+      return RemapAllocBuffer(node, it->second);
+    }
+    // Record for deferred remapping (flat IR case)
+    pending_alloc_buffers_.emplace_back(orig_data_ptr);
+    return node;
+  }
 
-      if (buf.scope() == "shared") {
-        write_ptr->body = AttrStmt(buf->data, tir::attr::volatile_scope, 1, 
write_ptr->body);
-      }
+  /*!
+   * \brief Remap an AllocBuffer node to use the replacement buffer.
+   * \param node The original AllocBuffer node.
+   * \param replacement The replacement buffer.
+   * \return The remapped statement(s).
+   */
+  Stmt RemapAllocBuffer(AllocBuffer node, const Buffer& replacement) {
+    node.CopyOnWrite()->buffer = replacement;
+    if (replacement.scope() == "shared") {
+      Stmt volatile_attr = AttrStmt(replacement->data, 
tir::attr::volatile_scope, 1, Evaluate(0));
+      return SeqStmt::Flatten(node, volatile_attr);

Review Comment:
   
![security-high](https://www.gstatic.com/codereviewagent/security-high-priority.svg)
 ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   The refactoring of `AllocBuffer` to a flat statement has introduced a 
critical synchronization bug. The `volatile_scope` attribute is applied via an 
`AttrStmt` with an empty body (`Evaluate(0)`), which is then placed as a 
sibling to the `AllocBuffer` in a `SeqStmt`. In TIR, an `AttrStmt` only applies 
to its `body`. Consequently, the `volatile_scope` attribute will not be in 
scope for any subsequent statements that perform loads or stores on the shared 
memory buffer. This leads to the GPU codegen omitting the `volatile` qualifier, 
causing incorrect caching in registers, data races, and incorrect results 
during multi-threaded reductions. This issue renders the `volatile` attribute 
ineffective and can lead to incorrect code generation. A potential fix involves 
identifying the scope where this buffer is used (e.g., the enclosing loop or 
function body) and wrapping that entire scope with the `AttrStmt`, similar to 
how `WarpMemoryRewriter` is implemented in `lower_warp_memory.cc`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to