kimm240 commented on code in PR #18636:
URL: https://github.com/apache/tvm/pull/18636#discussion_r2703912041


##########
src/tir/schedule/primitive/compute_inline.cc:
##########
@@ -1308,53 +1179,163 @@ Block 
ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti
     var_map[epilogue_data_vars[i]] = reduction_data_vars[i];
   }
 
-  // 2. Change init to epilogue value based on epilogue type
-  BufferStore new_init_store;
-  if (epilogue_type_ == EpilogueType::BiasReLU) {
-    // For ReLU, init should be max(C[vi, vj], 0) to match per-iteration ReLU 
semantics
-    PrimExpr init_value = Substitute(epilogue_addend_, var_map);
-    PrimExpr zero = tir::make_zero(init_value.dtype());
-    new_init_store = BufferStore(epilogue_output_buffer_, Max(init_value, 
zero),
-                                 Substitute(epilogue_output_indices_, 
var_map));
-  } else if (epilogue_type_ == EpilogueType::Clipping) {
-    // For Clipping, init should be min(max(init_value, lower), upper)
-    // Since init is typically 0, this becomes min(max(0, lower), upper)
-    PrimExpr init_value = tir::make_zero(epilogue_output_buffer_->dtype);
-    PrimExpr clipped_init = Min(Max(init_value, Substitute(clipping_lower_, 
var_map)),
-                                Substitute(clipping_upper_, var_map));
-    new_init_store = BufferStore(epilogue_output_buffer_, clipped_init,
-                                 Substitute(epilogue_output_indices_, 
var_map));
-  } else {
-    // Bias: D[vi, vj] = C[vi, vj]
-    new_init_store = BufferStore(epilogue_output_buffer_, 
Substitute(epilogue_addend_, var_map),
-                                 Substitute(epilogue_output_indices_, 
var_map));
-  }
+  // 2. Generalized init transformation: substitute reduction buffer load with 
identity element (0)
+  // Create a substituter to replace reduction_buffer_load_ with identity 
element
+  class InitSubstituter : public ExprMutator {
+   public:
+    InitSubstituter(const Buffer& target_buffer, PrimExpr identity_elem)
+        : target_buffer_(target_buffer), identity_elem_(identity_elem) {}
+
+    PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+      BufferLoad load = Downcast<BufferLoad>(ExprMutator::VisitExpr_(op));
+      if (load->buffer.same_as(target_buffer_)) {
+        return identity_elem_;
+      }
+      return load;
+    }
+
+   private:
+    Buffer target_buffer_;
+    PrimExpr identity_elem_;
+  };
+
+  // Identity element for reduction (assumed to be 0 for addition-based 
reductions)
+  PrimExpr identity_elem = tir::make_zero(epilogue_output_buffer_->dtype);
+
+  // Substitute reduction buffer load with identity element
+  InitSubstituter init_subst(inlined_buffer_, identity_elem);
+  PrimExpr init_epilogue = init_subst(epilogue_expression_);
+
+  // Apply index mapping
+  init_epilogue = Substitute(init_epilogue, var_map);
+
+  // Simplify the expression (e.g., 0 + C[vi, vj] -> C[vi, vj])
+  arith::Analyzer analyzer;
+  init_epilogue = analyzer.Simplify(init_epilogue);
+
+  BufferStore new_init_store = BufferStore(epilogue_output_buffer_, 
init_epilogue,
+                                           
Substitute(epilogue_output_indices_, var_map));
   new_block->init = new_init_store;
 
-  // 3. Replace output buffer from temp to D in body
-  class BufferReplacer : public StmtExprMutator {
+  // 3. Generalized update transformation: apply epilogue expression with 
reduction buffer replaced
+  // If reduction buffer load's parent is Add and other operand is not a 
reduction buffer,
+  // remove that operand (bias addend) from update expression
+  class UpdateSubstituter : public StmtExprMutator {
    public:
-    BufferReplacer(Buffer old_buf, Buffer new_buf, EpilogueType epilogue_type, 
DataType dtype,
-                   PrimExpr clipping_lower = PrimExpr(), PrimExpr 
clipping_upper = PrimExpr())
+    UpdateSubstituter(const Buffer& old_buf, const Buffer& new_buf, const 
Buffer& reduction_buf,
+                      const PrimExpr& epilogue_expr, const 
std::unordered_map<Var, Var>& var_map)
         : old_buffer_(old_buf),
           new_buffer_(new_buf),
-          epilogue_type_(epilogue_type),
-          dtype_(dtype),
-          clipping_lower_(clipping_lower),
-          clipping_upper_(clipping_upper) {}
+          reduction_buffer_(reduction_buf),
+          epilogue_expression_(epilogue_expr),
+          var_map_(var_map) {}
 
     Stmt VisitStmt_(const BufferStoreNode* op) final {
       BufferStore store = 
Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
       if (store->buffer.same_as(old_buffer_)) {
-        PrimExpr new_value = store->value;
-        // For ReLU, apply max per iteration to match per-iteration ReLU 
semantics
-        if (epilogue_type_ == EpilogueType::BiasReLU) {
-          PrimExpr zero = tir::make_zero(dtype_);
-          new_value = Max(new_value, zero);
-        } else if (epilogue_type_ == EpilogueType::Clipping) {
-          // For Clipping, apply min(max(value, lower), upper) per iteration
-          new_value = Min(Max(new_value, clipping_lower_), clipping_upper_);
-        }
+        // Replace old_buffer_ in store->value with new_buffer_ to get the 
reduction update
+        // expression This ensures store->value references new_buffer_ instead 
of old_buffer_
+        class ReductionUpdateReplacer : public ExprMutator {
+         public:
+          ReductionUpdateReplacer(const Buffer& old_buf, const Buffer& new_buf)
+              : old_buffer_(old_buf), new_buffer_(new_buf) {}
+
+          PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+            BufferLoad load = 
Downcast<BufferLoad>(ExprMutator::VisitExpr_(op));
+            if (load->buffer.same_as(old_buffer_)) {
+              return BufferLoad(new_buffer_, load->indices);
+            }
+            return load;
+          }
+
+         private:
+          Buffer old_buffer_;
+          Buffer new_buffer_;
+        };
+
+        ReductionUpdateReplacer reduction_replacer(old_buffer_, new_buffer_);
+        PrimExpr reduction_update = reduction_replacer(store->value);
+
+        // Generalized approach: apply epilogue expression with reduction 
buffer load replaced
+        // If reduction buffer load's direct parent is Add and the other 
operand is not a reduction
+        // buffer, remove that operand (bias addend) from the update expression
+        class GeneralizedEpilogueApplier : public ExprMutator {
+         public:
+          GeneralizedEpilogueApplier(const Buffer& target_buf, const Buffer& 
reduction_buf,
+                                     const PrimExpr& replacement)
+              : target_buffer_(target_buf),
+                reduction_buffer_(reduction_buf),
+                replacement_(replacement),
+                found_target_load_(false) {}
+
+          PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+            BufferLoad load = 
Downcast<BufferLoad>(ExprMutator::VisitExpr_(op));
+            if (load->buffer.same_as(target_buffer_)) {
+              found_target_load_ = true;
+              // Check if parent is Add (will be checked in VisitExpr_(const 
AddNode*))
+              return replacement_;
+            }
+            return load;
+          }
+
+          PrimExpr VisitExpr_(const AddNode* op) final {

Review Comment:
   To fix the Case 1 issue, I added code in commit da8354b.
   Case 2 is already handled in this PR.
   
   I added test cases for both issues and confirmed that the tests pass.
   If there are other issues to solve, let me know and I will fix them.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to