This is an automated email from the ASF dual-hosted git repository.

wrongtest pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4b2b639fa1 [TIR][Schedule]Generalize fuseReductionEpilogue to support arbitrary epilogue expressions (#18636)
4b2b639fa1 is described below

commit 4b2b639fa187684637be55865f7cc84c47deb2eb
Author: kimm240 <[email protected]>
AuthorDate: Mon Feb 2 20:26:57 2026 +0900

    [TIR][Schedule]Generalize fuseReductionEpilogue to support arbitrary epilogue expressions (#18636)
    
    ## Major Changes for Generalization
    
    ### 1. Pattern Matching Removal
    
    **Removed Items:**
    - `EpilogueType` enum (Bias, BiasReLU, Clipping)
    - `AnalyzeEpiloguePattern()` function
    - Pattern-specific branching logic
    
    **Current Approach:**
    - Directly process the entire epilogue expression without pattern
    matching
    
    ### 2. Store Entire Epilogue Expression
    
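    - Store the entire epilogue expression in `epilogue_expression_`
    - Use the expression directly, without pattern analysis
    - Require the reduction buffer to be loaded exactly once in the expression; otherwise fusion is rejected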
    
    ```cpp
    // Store the epilogue expression and reduction buffer load
    epilogue_expression_ = inlined_store_->value;
    reduction_buffer_load_ = loads[0];
    ```
    ### 3. Generalized Init Transformation
    
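    - Replace the reduction buffer load with the reduction's identity element (0; an additive reduction is assumed)
    - Apply this substitution to the entire epilogue expression, then simplify the result to obtain the init value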
    
    ```cpp
    InitSubstituter init_subst(inlined_buffer_, identity_elem);
    PrimExpr init_epilogue = init_subst(epilogue_expression_);
    // Simplify: 0 + C[vi, vj] -> C[vi, vj]
    ```
    
    **Examples:**
    - `temp + C` → `0 + C` → `C` (simplify)
    - `max(temp + C, 0)` → `max(0 + C, 0)` → `max(C, 0)`
    - `min(max(temp, lower), upper)` → `min(max(0, lower), upper)`
    
    ### 4. Generalized Update Transformation
    
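    - Replace the reduction buffer load with the reduction update
    - If the load's parent is an Add whose other operand is not a load from the reduction buffer,
      treat that operand as a bias addend and remove it from the update: it is already part of the
      init value, so keeping it would add the bias again on every reduction step
    - Otherwise, apply the expression as-is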
    
    ```cpp
    class GeneralizedEpilogueApplier : public ExprMutator {
      // Replace reduction buffer load with reduction update
      // Automatically detect and remove bias addend in Add nodes
      // Automatically support other activation functions
    };
    ```
    
    ## Results and Verification
    
    ### Existing Tests Pass
    
    All existing tests pass, maintaining backward compatibility:
    
    - `test_fuse_reduction_epilogue_basic`
    - `test_fuse_reduction_epilogue_fp32`
    - `test_fuse_reduction_epilogue_numerical_correctness`
    - `test_fuse_reduction_epilogue_multiple_epilogue`
    - `test_matmul_bias_relu`
    - `test_matmul_bias_relu_correctness_unified`
    - `test_matmul_clipping`
    - `test_matmul_clipping_correctness_unified`
    - Other commutative variants tests
    
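    Total: All 15 tests pass
    
    In addition, two negative tests check that fusion is rejected when the epilogue uses the
    reduction result more than once (`test_fuse_reduction_epilogue_reject_multiple_use`) or scales
    it with Mul/Div/Mod (`test_fuse_reduction_epilogue_reject_scaling`).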
    
    ---------
    
    Signed-off-by: Hyun Gyu Kim <[email protected]>
    Co-authored-by: Hyun Gyu Kim <[email protected]>
---
 src/tir/schedule/primitive/compute_inline.cc       | 487 ++++++++++++---------
 .../test_tir_schedule_fuse_reduction_epilogue.py   |  59 +++
 2 files changed, 334 insertions(+), 212 deletions(-)

diff --git a/src/tir/schedule/primitive/compute_inline.cc 
b/src/tir/schedule/primitive/compute_inline.cc
index 0e3fc5e2a2..c60954deaf 100644
--- a/src/tir/schedule/primitive/compute_inline.cc
+++ b/src/tir/schedule/primitive/compute_inline.cc
@@ -986,15 +986,8 @@ void ReverseComputeInline(ScheduleState self, const 
StmtSRef& consumer_block_sre
 
 /*!
  * \brief Helper to fuse epilogue block into reduction block
- * Analyzes epilogue pattern and transforms reduction init/update
+ * Uses generalized approach to handle any epilogue expression without pattern 
matching
  */
-// Epilogue type enumeration
-enum class EpilogueType {
-  Bias,      // temp + C
-  BiasReLU,  // max(temp + C, 0)
-  Clipping,  // min(max(temp, lower), upper)
-};
-
 class ReductionEpilogueFuser : public BaseInliner {
  public:
   explicit ReductionEpilogueFuser(const Buffer& reduction_buffer, const 
SBlockNode* reduction_block,
@@ -1002,8 +995,7 @@ class ReductionEpilogueFuser : public BaseInliner {
                                   const StmtSRef& scope_root_sref)
       : BaseInliner(reduction_buffer, epilogue_block_realize->block, 
scope_root_sref),
         reduction_block_(reduction_block),
-        epilogue_block_(epilogue_block_realize->block.get()),
-        epilogue_type_(EpilogueType::Bias) {
+        epilogue_block_(epilogue_block_realize->block.get()) {
     // Disable opaque access check for epilogue fusion
     // Epilogue blocks can read multiple buffers (temp + bias), which is 
allowed
     has_opaque_access = false;
@@ -1023,7 +1015,6 @@ class ReductionEpilogueFuser : public BaseInliner {
                                    const SBlockRealizeNode* reduction_realize);
 
  private:
-  bool AnalyzeEpiloguePattern(const PrimExpr& value);
   bool IsReductionBlock(const SBlockNode* block);
   void ExtractEpilogueInfo();
   // Helper function to extract BufferLoad nodes from BufferStore
@@ -1052,15 +1043,16 @@ class ReductionEpilogueFuser : public BaseInliner {
 
   const SBlockNode* reduction_block_;
   const SBlockNode* epilogue_block_;
-  PrimExpr epilogue_addend_{nullptr};                      // C[vi, vj] in D = 
temp + C
-  Buffer epilogue_output_buffer_{nullptr};                 // Output buffer D
+  // Generalized approach: store the entire epilogue expression
+  PrimExpr epilogue_expression_{
+      nullptr};  // The entire epilogue expression (e.g., temp + C, max(temp + 
C, 0))
+  const BufferLoadNode* reduction_buffer_load_{
+      nullptr};                             // The reduction buffer load in 
epilogue expression
+  Buffer epilogue_output_buffer_{nullptr};  // Output buffer D
   ffi::Array<PrimExpr> epilogue_output_indices_{nullptr};  // Indices of D[vi, 
vj]
   BufferRegion epilogue_output_region_{nullptr};           // Write region of D
-  Buffer epilogue_addend_buffer_{nullptr};                 // Addend buffer C
-  BufferRegion epilogue_addend_region_{nullptr};           // Read region of C
-  EpilogueType epilogue_type_;                             // Type of epilogue 
operation
-  PrimExpr clipping_lower_{nullptr};                       // Lower bound for 
clipping
-  PrimExpr clipping_upper_{nullptr};                       // Upper bound for 
clipping
+  Buffer epilogue_addend_buffer_{nullptr};        // Additional buffer (e.g., 
bias buffer C)
+  BufferRegion epilogue_addend_region_{nullptr};  // Read region of additional 
buffer
 };
 
 bool ReductionEpilogueFuser::BodyPatternAllowFusion(const SBlockRealize& 
epilogue_block_realize) {
@@ -1083,166 +1075,112 @@ bool 
ReductionEpilogueFuser::BodyPatternAllowFusion(const SBlockRealize& epilogu
     return false;
   }
 
-  // 4. Analyze epilogue pattern: D[i,j] = temp[i,j] + C[i,j] or
-  //    D[i,j] = min(max(temp[i,j], lower), upper)
-  if (!AnalyzeEpiloguePattern(inlined_store_->value)) {
-    // Failure: epilogue is not a supported pattern (Bias, BiasReLU, or 
Clipping)
-    return false;
-  }
-
-  // 5. Verify temp appears exactly once in the epilogue pattern
-  // This ensures correctness for all supported patterns (Bias, BiasReLU, 
Clipping)
-  // The reduction result buffer must be used exactly once in the epilogue 
expression
+  // 4. Generalized approach: store the entire epilogue expression
+  // Verify reduction buffer appears exactly once (required for fusion 
correctness)
   if (loads.size() != 1) {
     // Failure: The reduction result (temp) must be used exactly once in the
     // epilogue expression for fusion.
     return false;
   }
 
-  // 6. Check if producer is a reduction block
-  if (!IsReductionBlock(reduction_block_)) {
-    // Failure: producer is not a reduction block
-    return false;
-  }
-
-  // 7. Extract epilogue information (output buffer, indices, regions, etc.)
-  ExtractEpilogueInfo();
+  // Store the epilogue expression and reduction buffer load
+  epilogue_expression_ = inlined_store_->value;
+  reduction_buffer_load_ = loads[0];
 
-  return true;
-}
+  // 5. Reject epilogues that scale the reduction result with non-additive ops
+  // For example, (reduce_out * 2.0) + C[i] is not a valid bias-style epilogue.
+  // We only allow the reduction result to be combined via Add/Min/Max shells.
+  class ScalingDetector : public ExprVisitor {
+   public:
+    explicit ScalingDetector(const Buffer& buffer) : buffer_(buffer) {}
 
-bool ReductionEpilogueFuser::AnalyzeEpiloguePattern(const PrimExpr& value) {
-  // Pattern 1: temp[i,j] + C[i,j] or C[i,j] + temp[i,j] (Bias)
-  if (const auto* add = value.as<AddNode>()) {
-    const auto* load_a = add->a.as<BufferLoadNode>();
-    const auto* load_b = add->b.as<BufferLoadNode>();
+    bool HasScaling(const PrimExpr& expr) {
+      has_scaling_ = false;
+      VisitExpr(expr);
+      return has_scaling_;
+    }
 
-    bool a_is_target = load_a && load_a->buffer.same_as(inlined_buffer_);
-    bool b_is_target = load_b && load_b->buffer.same_as(inlined_buffer_);
+   private:
+    // Helper to check if a subtree contains a load from the reduction buffer
+    bool ContainsTarget(const PrimExpr& expr) {
+      class TargetFinder : public ExprVisitor {
+       public:
+        explicit TargetFinder(const Buffer& buffer) : buffer_(buffer) {}
+
+        bool Find(const PrimExpr& e) {
+          found_ = false;
+          VisitExpr(e);
+          return found_;
+        }
 
-    // Ensure exactly one operand is from the reduction buffer
-    if (a_is_target != b_is_target) {
-      epilogue_addend_ = a_is_target ? add->b : add->a;
-      epilogue_type_ = EpilogueType::Bias;
-      return true;
-    }
-  }
+       private:
+        void VisitExpr_(const BufferLoadNode* op) final {
+          if (op->buffer.same_as(buffer_)) {
+            found_ = true;
+            return;
+          }
+          ExprVisitor::VisitExpr_(op);
+        }
 
-  // Pattern 2: min(max(temp[i,j], lower), upper) or max(min(temp[i,j], 
upper), lower) (Clipping)
-  // Handle all commutative variants of min/max at each level.
+        Buffer buffer_;
+        bool found_{false};
+      };
 
-  // Helper to check if an expression is a load from the reduction buffer, and
-  // return the other operand as `other` if so.
-  auto match_buffer_in_commutative_op = [this](const PrimExpr& a, const 
PrimExpr& b,
-                                               PrimExpr* other) -> bool {
-    if (const auto* load_a = a.as<BufferLoadNode>()) {
-      if (load_a->buffer.same_as(inlined_buffer_)) {
-        *other = b;
-        return true;
-      }
-    }
-    if (const auto* load_b = b.as<BufferLoadNode>()) {
-      if (load_b->buffer.same_as(inlined_buffer_)) {
-        *other = a;
-        return true;
-      }
+      TargetFinder finder(buffer_);
+      return finder.Find(expr);
     }
-    return false;
-  };
 
-  // Check for min(max(temp, lower), upper) and commutative variants
-  if (const auto* min_node = value.as<MinNode>()) {
-    const MaxNode* max_node = nullptr;
-    PrimExpr upper;
-    // Try both (a, b) as possible positions of the inner max
-    if ((max_node = min_node->a.as<MaxNode>())) {
-      upper = min_node->b;
-    } else if ((max_node = min_node->b.as<MaxNode>())) {
-      upper = min_node->a;
-    }
-    if (max_node != nullptr) {
-      PrimExpr lower;
-      if (match_buffer_in_commutative_op(max_node->a, max_node->b, &lower)) {
-        clipping_lower_ = lower;
-        clipping_upper_ = upper;
-        epilogue_type_ = EpilogueType::Clipping;
-        return true;
+    void VisitExpr_(const MulNode* op) final {
+      if (has_scaling_) return;
+      // If either operand subtree contains the reduction buffer load,
+      // we treat this as invalid scaling of the reduction result.
+      if (ContainsTarget(op->a) || ContainsTarget(op->b)) {
+        has_scaling_ = true;
+        return;
       }
+      ExprVisitor::VisitExpr_(op);
     }
-  }
 
-  // Check for max(min(temp[i,j], upper), lower) and commutative variants
-  if (const auto* max_node = value.as<MaxNode>()) {
-    const MinNode* min_node = nullptr;
-    PrimExpr lower;
-    // Try both (a, b) as possible positions of the inner min
-    if ((min_node = max_node->a.as<MinNode>())) {
-      lower = max_node->b;
-    } else if ((min_node = max_node->b.as<MinNode>())) {
-      lower = max_node->a;
-    }
-    if (min_node != nullptr) {
-      PrimExpr upper;
-      if (match_buffer_in_commutative_op(min_node->a, min_node->b, &upper)) {
-        clipping_lower_ = lower;
-        clipping_upper_ = upper;
-        epilogue_type_ = EpilogueType::Clipping;
-        return true;
+    void VisitExpr_(const DivNode* op) final {
+      if (has_scaling_) return;
+      if (ContainsTarget(op->a) || ContainsTarget(op->b)) {
+        has_scaling_ = true;
+        return;
       }
+      ExprVisitor::VisitExpr_(op);
     }
-  }
 
-  // Pattern 3: max(temp[i,j] + C[i,j], 0) or max(C[i,j] + temp[i,j], 0) 
(BiasReLU)
-  // Also handle max(0, temp[i,j] + C[i,j]) or max(0, C[i,j] + temp[i,j])
-  if (const auto* max_node = value.as<MaxNode>()) {
-    // Check if either operand is zero (ReLU: max(x, 0) or max(0, x))
-    // Support both integer and float zero constants.
-    const PrimExpr* add_candidate = nullptr;
-    bool is_zero_const = false;
-    auto is_zero_expr = [](const PrimExpr& expr) -> bool {
-      if (tir::is_zero(expr)) {
-        return true;
-      }
-      if (const auto* float_imm = expr.as<FloatImmNode>()) {
-        return float_imm->value == 0.0;
+    void VisitExpr_(const ModNode* op) final {
+      if (has_scaling_) return;
+      if (ContainsTarget(op->a) || ContainsTarget(op->b)) {
+        has_scaling_ = true;
+        return;
       }
-      return false;
-    };
-
-    if (is_zero_expr(max_node->a)) {
-      is_zero_const = true;
-      add_candidate = &max_node->b;
-    } else if (is_zero_expr(max_node->b)) {
-      is_zero_const = true;
-      add_candidate = &max_node->a;
+      ExprVisitor::VisitExpr_(op);
     }
 
-    if (is_zero_const && add_candidate != nullptr) {
-      if (const auto* add = add_candidate->as<AddNode>()) {
-        const auto* load_a = add->a.as<BufferLoadNode>();
-        const auto* load_b = add->b.as<BufferLoadNode>();
-
-        bool a_is_target = load_a && load_a->buffer.same_as(inlined_buffer_);
-        bool b_is_target = load_b && load_b->buffer.same_as(inlined_buffer_);
+    Buffer buffer_;
+    bool has_scaling_{false};
+  };
 
-        // Ensure exactly one operand is from the reduction buffer
-        if (a_is_target != b_is_target) {
-          epilogue_addend_ = a_is_target ? add->b : add->a;
-          epilogue_type_ = EpilogueType::BiasReLU;
-          return true;
-        }
-      } else if (const auto* load = add_candidate->as<BufferLoadNode>()) {
-        // Handle bias-free ReLU: max(temp, 0) or max(0, temp)
-        if (load->buffer.same_as(inlined_buffer_)) {
-          epilogue_addend_ = tir::make_zero(load->dtype);
-          epilogue_type_ = EpilogueType::BiasReLU;
-          return true;
-        }
-      }
+  {
+    ScalingDetector detector(inlined_buffer_);
+    if (detector.HasScaling(inlined_store_->value)) {
+      // Failure: Non-additive scaling of the reduction result is not supported
+      return false;
     }
   }
 
-  return false;
+  // 6. Check if producer is a reduction block
+  if (!IsReductionBlock(reduction_block_)) {
+    // Failure: producer is not a reduction block
+    return false;
+  }
+
+  // 7. Extract epilogue information (output buffer, indices, regions, etc.)
+  ExtractEpilogueInfo();
+
+  return true;
 }
 
 bool ReductionEpilogueFuser::IsReductionBlock(const SBlockNode* block) {
@@ -1268,12 +1206,29 @@ void ReductionEpilogueFuser::ExtractEpilogueInfo() {
     }
   }
 
-  // Extract epilogue addend buffer and region from epilogue_addend_
-  if (const auto* load = epilogue_addend_.as<BufferLoadNode>()) {
-    epilogue_addend_buffer_ = load->buffer;
+  // Generalized approach: extract all non-reduction buffers from epilogue 
expression
+  // Find all buffers in epilogue expression (except the reduction buffer)
+  struct BufferExtractor : public ExprVisitor {
+    void VisitExpr_(const BufferLoadNode* load) final {
+      if (!load->buffer.same_as(reduction_buffer)) {
+        other_buffers.insert(load->buffer.get());
+      }
+      ExprVisitor::VisitExpr_(load);
+    }
+    Buffer reduction_buffer;
+    std::unordered_set<const BufferNode*> other_buffers;
+  } extractor;
+  extractor.reduction_buffer = inlined_buffer_;
+  extractor(epilogue_expression_);
+
+  // Extract the first non-reduction buffer and its region
+  // In most cases, there's one additional buffer (e.g., bias buffer)
+  if (!extractor.other_buffers.empty()) {
+    const BufferNode* first_buffer = *extractor.other_buffers.begin();
+    epilogue_addend_buffer_ = ffi::GetRef<Buffer>(first_buffer);
     // Find the read region from epilogue block reads
     for (const BufferRegion& read : epilogue_block_->reads) {
-      if (read->buffer.same_as(epilogue_addend_buffer_)) {
+      if (read->buffer.get() == first_buffer) {
         epilogue_addend_region_ = read;
         break;
       }
@@ -1308,53 +1263,163 @@ SBlock 
ReductionEpilogueFuser::CreateFusedReductionBlock(
     var_map[epilogue_data_vars[i]] = reduction_data_vars[i];
   }
 
-  // 2. Change init to epilogue value based on epilogue type
-  BufferStore new_init_store;
-  if (epilogue_type_ == EpilogueType::BiasReLU) {
-    // For ReLU, init should be max(C[vi, vj], 0) to match per-iteration ReLU 
semantics
-    PrimExpr init_value = Substitute(epilogue_addend_, var_map);
-    PrimExpr zero = tir::make_zero(init_value.dtype());
-    new_init_store = BufferStore(epilogue_output_buffer_, Max(init_value, 
zero),
-                                 Substitute(epilogue_output_indices_, 
var_map));
-  } else if (epilogue_type_ == EpilogueType::Clipping) {
-    // For Clipping, init should be min(max(init_value, lower), upper)
-    // Since init is typically 0, this becomes min(max(0, lower), upper)
-    PrimExpr init_value = tir::make_zero(epilogue_output_buffer_->dtype);
-    PrimExpr clipped_init = Min(Max(init_value, Substitute(clipping_lower_, 
var_map)),
-                                Substitute(clipping_upper_, var_map));
-    new_init_store = BufferStore(epilogue_output_buffer_, clipped_init,
-                                 Substitute(epilogue_output_indices_, 
var_map));
-  } else {
-    // Bias: D[vi, vj] = C[vi, vj]
-    new_init_store = BufferStore(epilogue_output_buffer_, 
Substitute(epilogue_addend_, var_map),
-                                 Substitute(epilogue_output_indices_, 
var_map));
-  }
+  // 2. Generalized init transformation: substitute reduction buffer load with 
identity element (0)
+  // Create a substituter to replace reduction_buffer_load_ with identity 
element
+  class InitSubstituter : public ExprMutator {
+   public:
+    InitSubstituter(const Buffer& target_buffer, PrimExpr identity_elem)
+        : target_buffer_(target_buffer), identity_elem_(identity_elem) {}
+
+    PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+      BufferLoad load = Downcast<BufferLoad>(ExprMutator::VisitExpr_(op));
+      if (load->buffer.same_as(target_buffer_)) {
+        return identity_elem_;
+      }
+      return load;
+    }
+
+   private:
+    Buffer target_buffer_;
+    PrimExpr identity_elem_;
+  };
+
+  // Identity element for reduction (assumed to be 0 for addition-based 
reductions)
+  PrimExpr identity_elem = tir::make_zero(epilogue_output_buffer_->dtype);
+
+  // Substitute reduction buffer load with identity element
+  InitSubstituter init_subst(inlined_buffer_, identity_elem);
+  PrimExpr init_epilogue = init_subst(epilogue_expression_);
+
+  // Apply index mapping
+  init_epilogue = Substitute(init_epilogue, var_map);
+
+  // Simplify the expression (e.g., 0 + C[vi, vj] -> C[vi, vj])
+  arith::Analyzer analyzer;
+  init_epilogue = analyzer.Simplify(init_epilogue);
+
+  BufferStore new_init_store = BufferStore(epilogue_output_buffer_, 
init_epilogue,
+                                           
Substitute(epilogue_output_indices_, var_map));
   new_block->init = new_init_store;
 
-  // 3. Replace output buffer from temp to D in body
-  class BufferReplacer : public StmtExprMutator {
+  // 3. Generalized update transformation: apply epilogue expression with 
reduction buffer replaced
+  // If reduction buffer load's parent is Add and other operand is not a 
reduction buffer,
+  // remove that operand (bias addend) from update expression
+  class UpdateSubstituter : public StmtExprMutator {
    public:
-    BufferReplacer(Buffer old_buf, Buffer new_buf, EpilogueType epilogue_type, 
DataType dtype,
-                   PrimExpr clipping_lower = PrimExpr(), PrimExpr 
clipping_upper = PrimExpr())
+    UpdateSubstituter(const Buffer& old_buf, const Buffer& new_buf, const 
Buffer& reduction_buf,
+                      const PrimExpr& epilogue_expr, const 
std::unordered_map<Var, Var>& var_map)
         : old_buffer_(old_buf),
           new_buffer_(new_buf),
-          epilogue_type_(epilogue_type),
-          dtype_(dtype),
-          clipping_lower_(clipping_lower),
-          clipping_upper_(clipping_upper) {}
+          reduction_buffer_(reduction_buf),
+          epilogue_expression_(epilogue_expr),
+          var_map_(var_map) {}
 
     Stmt VisitStmt_(const BufferStoreNode* op) final {
       BufferStore store = 
Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
       if (store->buffer.same_as(old_buffer_)) {
-        PrimExpr new_value = store->value;
-        // For ReLU, apply max per iteration to match per-iteration ReLU 
semantics
-        if (epilogue_type_ == EpilogueType::BiasReLU) {
-          PrimExpr zero = tir::make_zero(dtype_);
-          new_value = Max(new_value, zero);
-        } else if (epilogue_type_ == EpilogueType::Clipping) {
-          // For Clipping, apply min(max(value, lower), upper) per iteration
-          new_value = Min(Max(new_value, clipping_lower_), clipping_upper_);
-        }
+        // Replace old_buffer_ in store->value with new_buffer_ to get the 
reduction update
+        // expression This ensures store->value references new_buffer_ instead 
of old_buffer_
+        class ReductionUpdateReplacer : public ExprMutator {
+         public:
+          ReductionUpdateReplacer(const Buffer& old_buf, const Buffer& new_buf)
+              : old_buffer_(old_buf), new_buffer_(new_buf) {}
+
+          PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+            BufferLoad load = 
Downcast<BufferLoad>(ExprMutator::VisitExpr_(op));
+            if (load->buffer.same_as(old_buffer_)) {
+              return BufferLoad(new_buffer_, load->indices);
+            }
+            return load;
+          }
+
+         private:
+          Buffer old_buffer_;
+          Buffer new_buffer_;
+        };
+
+        ReductionUpdateReplacer reduction_replacer(old_buffer_, new_buffer_);
+        PrimExpr reduction_update = reduction_replacer(store->value);
+
+        // Generalized approach: apply epilogue expression with reduction 
buffer load replaced
+        // If reduction buffer load's direct parent is Add and the other 
operand is not a reduction
+        // buffer, remove that operand (bias addend) from the update expression
+        class GeneralizedEpilogueApplier : public ExprMutator {
+         public:
+          GeneralizedEpilogueApplier(const Buffer& target_buf, const Buffer& 
reduction_buf,
+                                     const PrimExpr& replacement)
+              : target_buffer_(target_buf),
+                reduction_buffer_(reduction_buf),
+                replacement_(replacement),
+                found_target_load_(false) {}
+
+          PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+            BufferLoad load = 
Downcast<BufferLoad>(ExprMutator::VisitExpr_(op));
+            if (load->buffer.same_as(target_buffer_)) {
+              found_target_load_ = true;
+              // Check if parent is Add (will be checked in VisitExpr_(const 
AddNode*))
+              return replacement_;
+            }
+            return load;
+          }
+
+          PrimExpr VisitExpr_(const AddNode* op) final {
+            // Visit children first to see if we find the target buffer load
+            bool found_before = found_target_load_;
+            found_target_load_ = false;
+
+            PrimExpr a = VisitExpr(op->a);
+            bool found_in_a = found_target_load_;
+            found_target_load_ = false;
+
+            PrimExpr b = VisitExpr(op->b);
+            bool found_in_b = found_target_load_;
+
+            // If target buffer load was found in this Add node
+            if (found_in_a || found_in_b) {
+              // Check if the other operand is NOT from the reduction buffer
+              // If so, it's likely a bias addend that should be removed in 
update
+              bool other_is_reduction = false;
+              if (found_in_a) {
+                // Check if b is from reduction buffer
+                if (const auto* load_b = b.as<BufferLoadNode>()) {
+                  other_is_reduction = 
load_b->buffer.same_as(reduction_buffer_);
+                }
+                if (!other_is_reduction) {
+                  // b is the bias addend, remove it
+                  return a;
+                }
+              } else {  // found_in_b
+                // Check if a is from reduction buffer
+                if (const auto* load_a = a.as<BufferLoadNode>()) {
+                  other_is_reduction = 
load_a->buffer.same_as(reduction_buffer_);
+                }
+                if (!other_is_reduction) {
+                  // a is the bias addend, remove it
+                  return b;
+                }
+              }
+              // If other operand is also from reduction buffer, keep the Add
+              return Add(a, b);
+            }
+
+            // Target buffer load not found in this Add, return as is
+            found_target_load_ = found_before;
+            return Add(a, b);
+          }
+
+         private:
+          const Buffer& target_buffer_;
+          const Buffer& reduction_buffer_;
+          const PrimExpr& replacement_;
+          bool found_target_load_;
+        };
+
+        GeneralizedEpilogueApplier applier(old_buffer_, reduction_buffer_, 
reduction_update);
+        PrimExpr new_value = applier(epilogue_expression_);
+
+        // Apply index mapping
+        new_value = Substitute(new_value, var_map_);
+
         return BufferStore(new_buffer_, new_value, store->indices);
       }
       return store;
@@ -1371,19 +1436,16 @@ SBlock 
ReductionEpilogueFuser::CreateFusedReductionBlock(
    private:
     Buffer old_buffer_;
     Buffer new_buffer_;
-    EpilogueType epilogue_type_;
-    DataType dtype_;
-    PrimExpr clipping_lower_;
-    PrimExpr clipping_upper_;
+    Buffer reduction_buffer_;
+    PrimExpr epilogue_expression_;
+    std::unordered_map<Var, Var> var_map_;
   };
 
-  DataType dtype = epilogue_output_buffer_->dtype;
-  PrimExpr clipping_lower_subst =
-      epilogue_type_ == EpilogueType::Clipping ? Substitute(clipping_lower_, 
var_map) : PrimExpr();
-  PrimExpr clipping_upper_subst =
-      epilogue_type_ == EpilogueType::Clipping ? Substitute(clipping_upper_, 
var_map) : PrimExpr();
-  BufferReplacer replacer(inlined_buffer_, epilogue_output_buffer_, 
epilogue_type_, dtype,
-                          clipping_lower_subst, clipping_upper_subst);
+  // Apply index mapping to epilogue expression first
+  PrimExpr epilogue_expr_mapped = Substitute(epilogue_expression_, var_map);
+
+  UpdateSubstituter replacer(inlined_buffer_, epilogue_output_buffer_, 
inlined_buffer_,
+                             epilogue_expr_mapped, var_map);
   new_block->body = replacer(reduction_block->body);
 
   // 4. Update write regions
@@ -1398,21 +1460,22 @@ SBlock 
ReductionEpilogueFuser::CreateFusedReductionBlock(
   }
   new_block->writes = new_writes;
 
-  // 5. Update read regions (C first, then A, B)
+  // 5. Update read regions: add all buffers from epilogue expression (except 
reduction buffer)
   ffi::Array<BufferRegion> new_reads;
   std::unordered_set<const BufferNode*> read_bufs;
 
-  // Add C buffer read first (used in init)
-  if (epilogue_addend_buffer_.defined()) {
-    new_reads.push_back(BufferRegion(epilogue_addend_buffer_,
-                                     
Substitute(epilogue_addend_region_->region, var_map)));
-    read_bufs.insert(epilogue_addend_buffer_.get());
+  // Add all non-reduction buffers from epilogue expression
+  for (const BufferRegion& read : epilogue_block_->reads) {
+    if (!read->buffer.same_as(inlined_buffer_)) {
+      new_reads.push_back(BufferRegion(read->buffer, Substitute(read->region, 
var_map)));
+      read_bufs.insert(read->buffer.get());
+    }
   }
 
-  // Add existing read regions (A, B, etc.)
+  // Add existing read regions from reduction block (A, B, etc.)
   for (const BufferRegion& read : reduction_block->reads) {
     if (!read->buffer.same_as(inlined_buffer_)) {
-      // Only add non-temp buffers
+      // Only add non-temp buffers that haven't been added yet
       if (read_bufs.find(read->buffer.get()) == read_bufs.end()) {
         new_reads.push_back(read);
         read_bufs.insert(read->buffer.get());
diff --git 
a/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py 
b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py
index 7210237f83..88b1062621 100644
--- a/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py
+++ b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py
@@ -214,5 +214,64 @@ def test_fuse_reduction_epilogue_multiple_epilogue():
     assert mod is not None
 
 
+@T.prim_func
+def matmul_bias_invalid_multiple_use_before(
+    A: T.Buffer((16, 16), "int8"),
+    B: T.Buffer((16, 16), "int8"),
+    C1: T.Buffer((16, 16), "int32"),
+    C2: T.Buffer((16, 16), "int32"),
+    D: T.Buffer((16, 16), "int32"),
+) -> None:
+    """Epilogue uses the reduction result twice; fusion must be rejected."""
+    temp = T.alloc_buffer((16, 16), dtype="int32")
+    for i, j, k in T.grid(16, 16, 16):
+        with T.sblock("multiply"):
+            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+            with T.init():
+                temp[vi, vj] = T.int32(0)
+            temp[vi, vj] = temp[vi, vj] + T.cast(A[vi, vk], "int32") * 
T.cast(B[vj, vk], "int32")
+    for i, j in T.grid(16, 16):
+        with T.sblock("bad_epilogue"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            # temp[vi, vj] is used twice in the epilogue expression
+            D[vi, vj] = (temp[vi, vj] + C1[vi, vj]) * (temp[vi, vj] + C2[vi, 
vj])
+
+
+def test_fuse_reduction_epilogue_reject_multiple_use():
+    """fusion should be rejected when the reduction result appears more than 
once."""
+    sch = tir.Schedule(matmul_bias_invalid_multiple_use_before, 
debug_mask="all")
+    with pytest.raises(tvm.tir.ScheduleError):
+        sch.fuse_reduction_epilogue("multiply", "bad_epilogue")
+
+
+@T.prim_func
+def matmul_bias_invalid_scaling_before(
+    A: T.Buffer((16, 16), "int8"),
+    B: T.Buffer((16, 16), "int8"),
+    C: T.Buffer((16, 16), "int32"),
+    D: T.Buffer((16, 16), "int32"),
+) -> None:
+    """Epilogue scales the reduction result; fusion must be rejected."""
+    temp = T.alloc_buffer((16, 16), dtype="int32")
+    for i, j, k in T.grid(16, 16, 16):
+        with T.sblock("multiply"):
+            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+            with T.init():
+                temp[vi, vj] = T.int32(0)
+            temp[vi, vj] = temp[vi, vj] + T.cast(A[vi, vk], "int32") * 
T.cast(B[vj, vk], "int32")
+    for i, j in T.grid(16, 16):
+        with T.sblock("scaled_epilogue"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            # temp[vi, vj] is scaled by 2 before adding bias; this must not be 
fused.
+            D[vi, vj] = temp[vi, vj] * T.int32(2) + C[vi, vj]
+
+
+def test_fuse_reduction_epilogue_reject_scaling():
+    """fusion should be rejected when the reduction result is scaled by 
Mul/Div/Mod."""
+    sch = tir.Schedule(matmul_bias_invalid_scaling_before, debug_mask="all")
+    with pytest.raises(tvm.tir.ScheduleError):
+        sch.fuse_reduction_epilogue("multiply", "scaled_epilogue")
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to