This is an automated email from the ASF dual-hosted git repository.
wrongtest pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4b2b639fa1 [TIR][Schedule]Generalize fuseReductionEpilogue to support
arbitrary epilogue expressions (#18636)
4b2b639fa1 is described below
commit 4b2b639fa187684637be55865f7cc84c47deb2eb
Author: kimm240 <[email protected]>
AuthorDate: Mon Feb 2 20:26:57 2026 +0900
[TIR][Schedule]Generalize fuseReductionEpilogue to support arbitrary
epilogue expressions (#18636)
## Major Changes for Generalization
### 1. Pattern Matching Removal
**Removed Items:**
- `EpilogueType` enum (Bias, BiasReLU, Clipping)
- `AnalyzeEpiloguePattern()` function
- Pattern-specific branching logic
**Current Approach:**
- The entire epilogue expression is now processed directly, without any pattern matching
### 2. Store Entire Epilogue Expression
- Store the entire epilogue expression in `epilogue_expression_`
- Use the expression directly without pattern analysis
```cpp
// Store the epilogue expression and reduction buffer load
epilogue_expression_ = inlined_store_->value;
reduction_buffer_load_ = loads[0];
```
### 3. Generalized Init Transformation
- Replace the reduction buffer load with the identity element (0)
- Apply this substitution to the entire epilogue expression to produce the init value, then simplify it
```cpp
InitSubstituter init_subst(inlined_buffer_, identity_elem);
PrimExpr init_epilogue = init_subst(epilogue_expression_);
// Simplify: 0 + C[vi, vj] -> C[vi, vj]
```
**Examples:**
- `temp + C` → `0 + C` → `C` (simplify)
- `max(temp + C, 0)` → `max(0 + C, 0)` → `max(C, 0)`
- `min(max(temp, lower), upper)` → `min(max(0, lower), upper)`
### 4. Generalized Update Transformation
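- Replace the reduction buffer load with the reduction update
- If the load's direct parent is an `Add` and the other operand does not read the reduction buffer,
treat that operand as a bias addend and drop it from the update (the init value already
contains it, so adding it on every reduction iteration would double-count it)
- Otherwise, apply the epilogue expression as-is
- Epilogues that use the reduction result more than once, or pass it through
`Mul`/`Div`/`Mod`, are rejected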
```cpp
class GeneralizedEpilogueApplier : public ExprMutator {
// Replace reduction buffer load with reduction update
// Automatically detect and remove bias addend in Add nodes
// Automatically support other activation functions
};
```
## Results and Verification
### Existing Tests Pass
All existing tests pass, maintaining backward compatibility:
- `test_fuse_reduction_epilogue_basic`
- `test_fuse_reduction_epilogue_fp32`
- `test_fuse_reduction_epilogue_numerical_correctness`
- `test_fuse_reduction_epilogue_multiple_epilogue`
- `test_matmul_bias_relu`
- `test_matmul_bias_relu_correctness_unified`
- `test_matmul_clipping`
- `test_matmul_clipping_correctness_unified`
- Other commutative-variant tests
Total: all 15 tests pass
---------
Signed-off-by: Hyun Gyu Kim <[email protected]>
Co-authored-by: Hyun Gyu Kim <[email protected]>
---
src/tir/schedule/primitive/compute_inline.cc | 487 ++++++++++++---------
.../test_tir_schedule_fuse_reduction_epilogue.py | 59 +++
2 files changed, 334 insertions(+), 212 deletions(-)
diff --git a/src/tir/schedule/primitive/compute_inline.cc
b/src/tir/schedule/primitive/compute_inline.cc
index 0e3fc5e2a2..c60954deaf 100644
--- a/src/tir/schedule/primitive/compute_inline.cc
+++ b/src/tir/schedule/primitive/compute_inline.cc
@@ -986,15 +986,8 @@ void ReverseComputeInline(ScheduleState self, const
StmtSRef& consumer_block_sre
/*!
* \brief Helper to fuse epilogue block into reduction block
- * Analyzes epilogue pattern and transforms reduction init/update
+ * Uses generalized approach to handle any epilogue expression without pattern
matching
*/
-// Epilogue type enumeration
-enum class EpilogueType {
- Bias, // temp + C
- BiasReLU, // max(temp + C, 0)
- Clipping, // min(max(temp, lower), upper)
-};
-
class ReductionEpilogueFuser : public BaseInliner {
public:
explicit ReductionEpilogueFuser(const Buffer& reduction_buffer, const
SBlockNode* reduction_block,
@@ -1002,8 +995,7 @@ class ReductionEpilogueFuser : public BaseInliner {
const StmtSRef& scope_root_sref)
: BaseInliner(reduction_buffer, epilogue_block_realize->block,
scope_root_sref),
reduction_block_(reduction_block),
- epilogue_block_(epilogue_block_realize->block.get()),
- epilogue_type_(EpilogueType::Bias) {
+ epilogue_block_(epilogue_block_realize->block.get()) {
// Disable opaque access check for epilogue fusion
// Epilogue blocks can read multiple buffers (temp + bias), which is
allowed
has_opaque_access = false;
@@ -1023,7 +1015,6 @@ class ReductionEpilogueFuser : public BaseInliner {
const SBlockRealizeNode* reduction_realize);
private:
- bool AnalyzeEpiloguePattern(const PrimExpr& value);
bool IsReductionBlock(const SBlockNode* block);
void ExtractEpilogueInfo();
// Helper function to extract BufferLoad nodes from BufferStore
@@ -1052,15 +1043,16 @@ class ReductionEpilogueFuser : public BaseInliner {
const SBlockNode* reduction_block_;
const SBlockNode* epilogue_block_;
- PrimExpr epilogue_addend_{nullptr}; // C[vi, vj] in D =
temp + C
- Buffer epilogue_output_buffer_{nullptr}; // Output buffer D
+ // Generalized approach: store the entire epilogue expression
+ PrimExpr epilogue_expression_{
+ nullptr}; // The entire epilogue expression (e.g., temp + C, max(temp +
C, 0))
+ const BufferLoadNode* reduction_buffer_load_{
+ nullptr}; // The reduction buffer load in
epilogue expression
+ Buffer epilogue_output_buffer_{nullptr}; // Output buffer D
ffi::Array<PrimExpr> epilogue_output_indices_{nullptr}; // Indices of D[vi,
vj]
BufferRegion epilogue_output_region_{nullptr}; // Write region of D
- Buffer epilogue_addend_buffer_{nullptr}; // Addend buffer C
- BufferRegion epilogue_addend_region_{nullptr}; // Read region of C
- EpilogueType epilogue_type_; // Type of epilogue
operation
- PrimExpr clipping_lower_{nullptr}; // Lower bound for
clipping
- PrimExpr clipping_upper_{nullptr}; // Upper bound for
clipping
+ Buffer epilogue_addend_buffer_{nullptr}; // Additional buffer (e.g.,
bias buffer C)
+ BufferRegion epilogue_addend_region_{nullptr}; // Read region of additional
buffer
};
bool ReductionEpilogueFuser::BodyPatternAllowFusion(const SBlockRealize&
epilogue_block_realize) {
@@ -1083,166 +1075,112 @@ bool
ReductionEpilogueFuser::BodyPatternAllowFusion(const SBlockRealize& epilogu
return false;
}
- // 4. Analyze epilogue pattern: D[i,j] = temp[i,j] + C[i,j] or
- // D[i,j] = min(max(temp[i,j], lower), upper)
- if (!AnalyzeEpiloguePattern(inlined_store_->value)) {
- // Failure: epilogue is not a supported pattern (Bias, BiasReLU, or
Clipping)
- return false;
- }
-
- // 5. Verify temp appears exactly once in the epilogue pattern
- // This ensures correctness for all supported patterns (Bias, BiasReLU,
Clipping)
- // The reduction result buffer must be used exactly once in the epilogue
expression
+ // 4. Generalized approach: store the entire epilogue expression
+ // Verify reduction buffer appears exactly once (required for fusion
correctness)
if (loads.size() != 1) {
// Failure: The reduction result (temp) must be used exactly once in the
// epilogue expression for fusion.
return false;
}
- // 6. Check if producer is a reduction block
- if (!IsReductionBlock(reduction_block_)) {
- // Failure: producer is not a reduction block
- return false;
- }
-
- // 7. Extract epilogue information (output buffer, indices, regions, etc.)
- ExtractEpilogueInfo();
+ // Store the epilogue expression and reduction buffer load
+ epilogue_expression_ = inlined_store_->value;
+ reduction_buffer_load_ = loads[0];
- return true;
-}
+ // 5. Reject epilogues that scale the reduction result with non-additive ops
+ // For example, (reduce_out * 2.0) + C[i] is not a valid bias-style epilogue.
+ // We only allow the reduction result to be combined via Add/Min/Max shells.
+ class ScalingDetector : public ExprVisitor {
+ public:
+ explicit ScalingDetector(const Buffer& buffer) : buffer_(buffer) {}
-bool ReductionEpilogueFuser::AnalyzeEpiloguePattern(const PrimExpr& value) {
- // Pattern 1: temp[i,j] + C[i,j] or C[i,j] + temp[i,j] (Bias)
- if (const auto* add = value.as<AddNode>()) {
- const auto* load_a = add->a.as<BufferLoadNode>();
- const auto* load_b = add->b.as<BufferLoadNode>();
+ bool HasScaling(const PrimExpr& expr) {
+ has_scaling_ = false;
+ VisitExpr(expr);
+ return has_scaling_;
+ }
- bool a_is_target = load_a && load_a->buffer.same_as(inlined_buffer_);
- bool b_is_target = load_b && load_b->buffer.same_as(inlined_buffer_);
+ private:
+ // Helper to check if a subtree contains a load from the reduction buffer
+ bool ContainsTarget(const PrimExpr& expr) {
+ class TargetFinder : public ExprVisitor {
+ public:
+ explicit TargetFinder(const Buffer& buffer) : buffer_(buffer) {}
+
+ bool Find(const PrimExpr& e) {
+ found_ = false;
+ VisitExpr(e);
+ return found_;
+ }
- // Ensure exactly one operand is from the reduction buffer
- if (a_is_target != b_is_target) {
- epilogue_addend_ = a_is_target ? add->b : add->a;
- epilogue_type_ = EpilogueType::Bias;
- return true;
- }
- }
+ private:
+ void VisitExpr_(const BufferLoadNode* op) final {
+ if (op->buffer.same_as(buffer_)) {
+ found_ = true;
+ return;
+ }
+ ExprVisitor::VisitExpr_(op);
+ }
- // Pattern 2: min(max(temp[i,j], lower), upper) or max(min(temp[i,j],
upper), lower) (Clipping)
- // Handle all commutative variants of min/max at each level.
+ Buffer buffer_;
+ bool found_{false};
+ };
- // Helper to check if an expression is a load from the reduction buffer, and
- // return the other operand as `other` if so.
- auto match_buffer_in_commutative_op = [this](const PrimExpr& a, const
PrimExpr& b,
- PrimExpr* other) -> bool {
- if (const auto* load_a = a.as<BufferLoadNode>()) {
- if (load_a->buffer.same_as(inlined_buffer_)) {
- *other = b;
- return true;
- }
- }
- if (const auto* load_b = b.as<BufferLoadNode>()) {
- if (load_b->buffer.same_as(inlined_buffer_)) {
- *other = a;
- return true;
- }
+ TargetFinder finder(buffer_);
+ return finder.Find(expr);
}
- return false;
- };
- // Check for min(max(temp, lower), upper) and commutative variants
- if (const auto* min_node = value.as<MinNode>()) {
- const MaxNode* max_node = nullptr;
- PrimExpr upper;
- // Try both (a, b) as possible positions of the inner max
- if ((max_node = min_node->a.as<MaxNode>())) {
- upper = min_node->b;
- } else if ((max_node = min_node->b.as<MaxNode>())) {
- upper = min_node->a;
- }
- if (max_node != nullptr) {
- PrimExpr lower;
- if (match_buffer_in_commutative_op(max_node->a, max_node->b, &lower)) {
- clipping_lower_ = lower;
- clipping_upper_ = upper;
- epilogue_type_ = EpilogueType::Clipping;
- return true;
+ void VisitExpr_(const MulNode* op) final {
+ if (has_scaling_) return;
+ // If either operand subtree contains the reduction buffer load,
+ // we treat this as invalid scaling of the reduction result.
+ if (ContainsTarget(op->a) || ContainsTarget(op->b)) {
+ has_scaling_ = true;
+ return;
}
+ ExprVisitor::VisitExpr_(op);
}
- }
- // Check for max(min(temp[i,j], upper), lower) and commutative variants
- if (const auto* max_node = value.as<MaxNode>()) {
- const MinNode* min_node = nullptr;
- PrimExpr lower;
- // Try both (a, b) as possible positions of the inner min
- if ((min_node = max_node->a.as<MinNode>())) {
- lower = max_node->b;
- } else if ((min_node = max_node->b.as<MinNode>())) {
- lower = max_node->a;
- }
- if (min_node != nullptr) {
- PrimExpr upper;
- if (match_buffer_in_commutative_op(min_node->a, min_node->b, &upper)) {
- clipping_lower_ = lower;
- clipping_upper_ = upper;
- epilogue_type_ = EpilogueType::Clipping;
- return true;
+ void VisitExpr_(const DivNode* op) final {
+ if (has_scaling_) return;
+ if (ContainsTarget(op->a) || ContainsTarget(op->b)) {
+ has_scaling_ = true;
+ return;
}
+ ExprVisitor::VisitExpr_(op);
}
- }
- // Pattern 3: max(temp[i,j] + C[i,j], 0) or max(C[i,j] + temp[i,j], 0)
(BiasReLU)
- // Also handle max(0, temp[i,j] + C[i,j]) or max(0, C[i,j] + temp[i,j])
- if (const auto* max_node = value.as<MaxNode>()) {
- // Check if either operand is zero (ReLU: max(x, 0) or max(0, x))
- // Support both integer and float zero constants.
- const PrimExpr* add_candidate = nullptr;
- bool is_zero_const = false;
- auto is_zero_expr = [](const PrimExpr& expr) -> bool {
- if (tir::is_zero(expr)) {
- return true;
- }
- if (const auto* float_imm = expr.as<FloatImmNode>()) {
- return float_imm->value == 0.0;
+ void VisitExpr_(const ModNode* op) final {
+ if (has_scaling_) return;
+ if (ContainsTarget(op->a) || ContainsTarget(op->b)) {
+ has_scaling_ = true;
+ return;
}
- return false;
- };
-
- if (is_zero_expr(max_node->a)) {
- is_zero_const = true;
- add_candidate = &max_node->b;
- } else if (is_zero_expr(max_node->b)) {
- is_zero_const = true;
- add_candidate = &max_node->a;
+ ExprVisitor::VisitExpr_(op);
}
- if (is_zero_const && add_candidate != nullptr) {
- if (const auto* add = add_candidate->as<AddNode>()) {
- const auto* load_a = add->a.as<BufferLoadNode>();
- const auto* load_b = add->b.as<BufferLoadNode>();
-
- bool a_is_target = load_a && load_a->buffer.same_as(inlined_buffer_);
- bool b_is_target = load_b && load_b->buffer.same_as(inlined_buffer_);
+ Buffer buffer_;
+ bool has_scaling_{false};
+ };
- // Ensure exactly one operand is from the reduction buffer
- if (a_is_target != b_is_target) {
- epilogue_addend_ = a_is_target ? add->b : add->a;
- epilogue_type_ = EpilogueType::BiasReLU;
- return true;
- }
- } else if (const auto* load = add_candidate->as<BufferLoadNode>()) {
- // Handle bias-free ReLU: max(temp, 0) or max(0, temp)
- if (load->buffer.same_as(inlined_buffer_)) {
- epilogue_addend_ = tir::make_zero(load->dtype);
- epilogue_type_ = EpilogueType::BiasReLU;
- return true;
- }
- }
+ {
+ ScalingDetector detector(inlined_buffer_);
+ if (detector.HasScaling(inlined_store_->value)) {
+ // Failure: Non-additive scaling of the reduction result is not supported
+ return false;
}
}
- return false;
+ // 6. Check if producer is a reduction block
+ if (!IsReductionBlock(reduction_block_)) {
+ // Failure: producer is not a reduction block
+ return false;
+ }
+
+ // 7. Extract epilogue information (output buffer, indices, regions, etc.)
+ ExtractEpilogueInfo();
+
+ return true;
}
bool ReductionEpilogueFuser::IsReductionBlock(const SBlockNode* block) {
@@ -1268,12 +1206,29 @@ void ReductionEpilogueFuser::ExtractEpilogueInfo() {
}
}
- // Extract epilogue addend buffer and region from epilogue_addend_
- if (const auto* load = epilogue_addend_.as<BufferLoadNode>()) {
- epilogue_addend_buffer_ = load->buffer;
+ // Generalized approach: extract all non-reduction buffers from epilogue
expression
+ // Find all buffers in epilogue expression (except the reduction buffer)
+ struct BufferExtractor : public ExprVisitor {
+ void VisitExpr_(const BufferLoadNode* load) final {
+ if (!load->buffer.same_as(reduction_buffer)) {
+ other_buffers.insert(load->buffer.get());
+ }
+ ExprVisitor::VisitExpr_(load);
+ }
+ Buffer reduction_buffer;
+ std::unordered_set<const BufferNode*> other_buffers;
+ } extractor;
+ extractor.reduction_buffer = inlined_buffer_;
+ extractor(epilogue_expression_);
+
+ // Extract the first non-reduction buffer and its region
+ // In most cases, there's one additional buffer (e.g., bias buffer)
+ if (!extractor.other_buffers.empty()) {
+ const BufferNode* first_buffer = *extractor.other_buffers.begin();
+ epilogue_addend_buffer_ = ffi::GetRef<Buffer>(first_buffer);
// Find the read region from epilogue block reads
for (const BufferRegion& read : epilogue_block_->reads) {
- if (read->buffer.same_as(epilogue_addend_buffer_)) {
+ if (read->buffer.get() == first_buffer) {
epilogue_addend_region_ = read;
break;
}
@@ -1308,53 +1263,163 @@ SBlock
ReductionEpilogueFuser::CreateFusedReductionBlock(
var_map[epilogue_data_vars[i]] = reduction_data_vars[i];
}
- // 2. Change init to epilogue value based on epilogue type
- BufferStore new_init_store;
- if (epilogue_type_ == EpilogueType::BiasReLU) {
- // For ReLU, init should be max(C[vi, vj], 0) to match per-iteration ReLU
semantics
- PrimExpr init_value = Substitute(epilogue_addend_, var_map);
- PrimExpr zero = tir::make_zero(init_value.dtype());
- new_init_store = BufferStore(epilogue_output_buffer_, Max(init_value,
zero),
- Substitute(epilogue_output_indices_,
var_map));
- } else if (epilogue_type_ == EpilogueType::Clipping) {
- // For Clipping, init should be min(max(init_value, lower), upper)
- // Since init is typically 0, this becomes min(max(0, lower), upper)
- PrimExpr init_value = tir::make_zero(epilogue_output_buffer_->dtype);
- PrimExpr clipped_init = Min(Max(init_value, Substitute(clipping_lower_,
var_map)),
- Substitute(clipping_upper_, var_map));
- new_init_store = BufferStore(epilogue_output_buffer_, clipped_init,
- Substitute(epilogue_output_indices_,
var_map));
- } else {
- // Bias: D[vi, vj] = C[vi, vj]
- new_init_store = BufferStore(epilogue_output_buffer_,
Substitute(epilogue_addend_, var_map),
- Substitute(epilogue_output_indices_,
var_map));
- }
+ // 2. Generalized init transformation: substitute reduction buffer load with
identity element (0)
+ // Create a substituter to replace reduction_buffer_load_ with identity
element
+ class InitSubstituter : public ExprMutator {
+ public:
+ InitSubstituter(const Buffer& target_buffer, PrimExpr identity_elem)
+ : target_buffer_(target_buffer), identity_elem_(identity_elem) {}
+
+ PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+ BufferLoad load = Downcast<BufferLoad>(ExprMutator::VisitExpr_(op));
+ if (load->buffer.same_as(target_buffer_)) {
+ return identity_elem_;
+ }
+ return load;
+ }
+
+ private:
+ Buffer target_buffer_;
+ PrimExpr identity_elem_;
+ };
+
+ // Identity element for reduction (assumed to be 0 for addition-based
reductions)
+ PrimExpr identity_elem = tir::make_zero(epilogue_output_buffer_->dtype);
+
+ // Substitute reduction buffer load with identity element
+ InitSubstituter init_subst(inlined_buffer_, identity_elem);
+ PrimExpr init_epilogue = init_subst(epilogue_expression_);
+
+ // Apply index mapping
+ init_epilogue = Substitute(init_epilogue, var_map);
+
+ // Simplify the expression (e.g., 0 + C[vi, vj] -> C[vi, vj])
+ arith::Analyzer analyzer;
+ init_epilogue = analyzer.Simplify(init_epilogue);
+
+ BufferStore new_init_store = BufferStore(epilogue_output_buffer_,
init_epilogue,
+
Substitute(epilogue_output_indices_, var_map));
new_block->init = new_init_store;
- // 3. Replace output buffer from temp to D in body
- class BufferReplacer : public StmtExprMutator {
+ // 3. Generalized update transformation: apply epilogue expression with
reduction buffer replaced
+ // If reduction buffer load's parent is Add and other operand is not a
reduction buffer,
+ // remove that operand (bias addend) from update expression
+ class UpdateSubstituter : public StmtExprMutator {
public:
- BufferReplacer(Buffer old_buf, Buffer new_buf, EpilogueType epilogue_type,
DataType dtype,
- PrimExpr clipping_lower = PrimExpr(), PrimExpr
clipping_upper = PrimExpr())
+ UpdateSubstituter(const Buffer& old_buf, const Buffer& new_buf, const
Buffer& reduction_buf,
+ const PrimExpr& epilogue_expr, const
std::unordered_map<Var, Var>& var_map)
: old_buffer_(old_buf),
new_buffer_(new_buf),
- epilogue_type_(epilogue_type),
- dtype_(dtype),
- clipping_lower_(clipping_lower),
- clipping_upper_(clipping_upper) {}
+ reduction_buffer_(reduction_buf),
+ epilogue_expression_(epilogue_expr),
+ var_map_(var_map) {}
Stmt VisitStmt_(const BufferStoreNode* op) final {
BufferStore store =
Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
if (store->buffer.same_as(old_buffer_)) {
- PrimExpr new_value = store->value;
- // For ReLU, apply max per iteration to match per-iteration ReLU
semantics
- if (epilogue_type_ == EpilogueType::BiasReLU) {
- PrimExpr zero = tir::make_zero(dtype_);
- new_value = Max(new_value, zero);
- } else if (epilogue_type_ == EpilogueType::Clipping) {
- // For Clipping, apply min(max(value, lower), upper) per iteration
- new_value = Min(Max(new_value, clipping_lower_), clipping_upper_);
- }
+ // Replace old_buffer_ in store->value with new_buffer_ to get the
reduction update
+ // expression This ensures store->value references new_buffer_ instead
of old_buffer_
+ class ReductionUpdateReplacer : public ExprMutator {
+ public:
+ ReductionUpdateReplacer(const Buffer& old_buf, const Buffer& new_buf)
+ : old_buffer_(old_buf), new_buffer_(new_buf) {}
+
+ PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+ BufferLoad load =
Downcast<BufferLoad>(ExprMutator::VisitExpr_(op));
+ if (load->buffer.same_as(old_buffer_)) {
+ return BufferLoad(new_buffer_, load->indices);
+ }
+ return load;
+ }
+
+ private:
+ Buffer old_buffer_;
+ Buffer new_buffer_;
+ };
+
+ ReductionUpdateReplacer reduction_replacer(old_buffer_, new_buffer_);
+ PrimExpr reduction_update = reduction_replacer(store->value);
+
+ // Generalized approach: apply epilogue expression with reduction
buffer load replaced
+ // If reduction buffer load's direct parent is Add and the other
operand is not a reduction
+ // buffer, remove that operand (bias addend) from the update expression
+ class GeneralizedEpilogueApplier : public ExprMutator {
+ public:
+ GeneralizedEpilogueApplier(const Buffer& target_buf, const Buffer&
reduction_buf,
+ const PrimExpr& replacement)
+ : target_buffer_(target_buf),
+ reduction_buffer_(reduction_buf),
+ replacement_(replacement),
+ found_target_load_(false) {}
+
+ PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+ BufferLoad load =
Downcast<BufferLoad>(ExprMutator::VisitExpr_(op));
+ if (load->buffer.same_as(target_buffer_)) {
+ found_target_load_ = true;
+ // Check if parent is Add (will be checked in VisitExpr_(const
AddNode*))
+ return replacement_;
+ }
+ return load;
+ }
+
+ PrimExpr VisitExpr_(const AddNode* op) final {
+ // Visit children first to see if we find the target buffer load
+ bool found_before = found_target_load_;
+ found_target_load_ = false;
+
+ PrimExpr a = VisitExpr(op->a);
+ bool found_in_a = found_target_load_;
+ found_target_load_ = false;
+
+ PrimExpr b = VisitExpr(op->b);
+ bool found_in_b = found_target_load_;
+
+ // If target buffer load was found in this Add node
+ if (found_in_a || found_in_b) {
+ // Check if the other operand is NOT from the reduction buffer
+ // If so, it's likely a bias addend that should be removed in
update
+ bool other_is_reduction = false;
+ if (found_in_a) {
+ // Check if b is from reduction buffer
+ if (const auto* load_b = b.as<BufferLoadNode>()) {
+ other_is_reduction =
load_b->buffer.same_as(reduction_buffer_);
+ }
+ if (!other_is_reduction) {
+ // b is the bias addend, remove it
+ return a;
+ }
+ } else { // found_in_b
+ // Check if a is from reduction buffer
+ if (const auto* load_a = a.as<BufferLoadNode>()) {
+ other_is_reduction =
load_a->buffer.same_as(reduction_buffer_);
+ }
+ if (!other_is_reduction) {
+ // a is the bias addend, remove it
+ return b;
+ }
+ }
+ // If other operand is also from reduction buffer, keep the Add
+ return Add(a, b);
+ }
+
+ // Target buffer load not found in this Add, return as is
+ found_target_load_ = found_before;
+ return Add(a, b);
+ }
+
+ private:
+ const Buffer& target_buffer_;
+ const Buffer& reduction_buffer_;
+ const PrimExpr& replacement_;
+ bool found_target_load_;
+ };
+
+ GeneralizedEpilogueApplier applier(old_buffer_, reduction_buffer_,
reduction_update);
+ PrimExpr new_value = applier(epilogue_expression_);
+
+ // Apply index mapping
+ new_value = Substitute(new_value, var_map_);
+
return BufferStore(new_buffer_, new_value, store->indices);
}
return store;
@@ -1371,19 +1436,16 @@ SBlock
ReductionEpilogueFuser::CreateFusedReductionBlock(
private:
Buffer old_buffer_;
Buffer new_buffer_;
- EpilogueType epilogue_type_;
- DataType dtype_;
- PrimExpr clipping_lower_;
- PrimExpr clipping_upper_;
+ Buffer reduction_buffer_;
+ PrimExpr epilogue_expression_;
+ std::unordered_map<Var, Var> var_map_;
};
- DataType dtype = epilogue_output_buffer_->dtype;
- PrimExpr clipping_lower_subst =
- epilogue_type_ == EpilogueType::Clipping ? Substitute(clipping_lower_,
var_map) : PrimExpr();
- PrimExpr clipping_upper_subst =
- epilogue_type_ == EpilogueType::Clipping ? Substitute(clipping_upper_,
var_map) : PrimExpr();
- BufferReplacer replacer(inlined_buffer_, epilogue_output_buffer_,
epilogue_type_, dtype,
- clipping_lower_subst, clipping_upper_subst);
+ // Apply index mapping to epilogue expression first
+ PrimExpr epilogue_expr_mapped = Substitute(epilogue_expression_, var_map);
+
+ UpdateSubstituter replacer(inlined_buffer_, epilogue_output_buffer_,
inlined_buffer_,
+ epilogue_expr_mapped, var_map);
new_block->body = replacer(reduction_block->body);
// 4. Update write regions
@@ -1398,21 +1460,22 @@ SBlock
ReductionEpilogueFuser::CreateFusedReductionBlock(
}
new_block->writes = new_writes;
- // 5. Update read regions (C first, then A, B)
+ // 5. Update read regions: add all buffers from epilogue expression (except
reduction buffer)
ffi::Array<BufferRegion> new_reads;
std::unordered_set<const BufferNode*> read_bufs;
- // Add C buffer read first (used in init)
- if (epilogue_addend_buffer_.defined()) {
- new_reads.push_back(BufferRegion(epilogue_addend_buffer_,
-
Substitute(epilogue_addend_region_->region, var_map)));
- read_bufs.insert(epilogue_addend_buffer_.get());
+ // Add all non-reduction buffers from epilogue expression
+ for (const BufferRegion& read : epilogue_block_->reads) {
+ if (!read->buffer.same_as(inlined_buffer_)) {
+ new_reads.push_back(BufferRegion(read->buffer, Substitute(read->region,
var_map)));
+ read_bufs.insert(read->buffer.get());
+ }
}
- // Add existing read regions (A, B, etc.)
+ // Add existing read regions from reduction block (A, B, etc.)
for (const BufferRegion& read : reduction_block->reads) {
if (!read->buffer.same_as(inlined_buffer_)) {
- // Only add non-temp buffers
+ // Only add non-temp buffers that haven't been added yet
if (read_bufs.find(read->buffer.get()) == read_bufs.end()) {
new_reads.push_back(read);
read_bufs.insert(read->buffer.get());
diff --git
a/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py
b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py
index 7210237f83..88b1062621 100644
--- a/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py
+++ b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py
@@ -214,5 +214,64 @@ def test_fuse_reduction_epilogue_multiple_epilogue():
assert mod is not None
+@T.prim_func
+def matmul_bias_invalid_multiple_use_before(
+ A: T.Buffer((16, 16), "int8"),
+ B: T.Buffer((16, 16), "int8"),
+ C1: T.Buffer((16, 16), "int32"),
+ C2: T.Buffer((16, 16), "int32"),
+ D: T.Buffer((16, 16), "int32"),
+) -> None:
+ """Epilogue uses the reduction result twice; fusion must be rejected."""
+ temp = T.alloc_buffer((16, 16), dtype="int32")
+ for i, j, k in T.grid(16, 16, 16):
+ with T.sblock("multiply"):
+ vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+ with T.init():
+ temp[vi, vj] = T.int32(0)
+ temp[vi, vj] = temp[vi, vj] + T.cast(A[vi, vk], "int32") *
T.cast(B[vj, vk], "int32")
+ for i, j in T.grid(16, 16):
+ with T.sblock("bad_epilogue"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ # temp[vi, vj] is used twice in the epilogue expression
+ D[vi, vj] = (temp[vi, vj] + C1[vi, vj]) * (temp[vi, vj] + C2[vi,
vj])
+
+
+def test_fuse_reduction_epilogue_reject_multiple_use():
+ """fusion should be rejected when the reduction result appears more than
once."""
+ sch = tir.Schedule(matmul_bias_invalid_multiple_use_before,
debug_mask="all")
+ with pytest.raises(tvm.tir.ScheduleError):
+ sch.fuse_reduction_epilogue("multiply", "bad_epilogue")
+
+
+@T.prim_func
+def matmul_bias_invalid_scaling_before(
+ A: T.Buffer((16, 16), "int8"),
+ B: T.Buffer((16, 16), "int8"),
+ C: T.Buffer((16, 16), "int32"),
+ D: T.Buffer((16, 16), "int32"),
+) -> None:
+ """Epilogue scales the reduction result; fusion must be rejected."""
+ temp = T.alloc_buffer((16, 16), dtype="int32")
+ for i, j, k in T.grid(16, 16, 16):
+ with T.sblock("multiply"):
+ vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+ with T.init():
+ temp[vi, vj] = T.int32(0)
+ temp[vi, vj] = temp[vi, vj] + T.cast(A[vi, vk], "int32") *
T.cast(B[vj, vk], "int32")
+ for i, j in T.grid(16, 16):
+ with T.sblock("scaled_epilogue"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ # temp[vi, vj] is scaled by 2 before adding bias; this must not be
fused.
+ D[vi, vj] = temp[vi, vj] * T.int32(2) + C[vi, vj]
+
+
+def test_fuse_reduction_epilogue_reject_scaling():
+ """fusion should be rejected when the reduction result is scaled by
Mul/Div/Mod."""
+ sch = tir.Schedule(matmul_bias_invalid_scaling_before, debug_mask="all")
+ with pytest.raises(tvm.tir.ScheduleError):
+ sch.fuse_reduction_epilogue("multiply", "scaled_epilogue")
+
+
if __name__ == "__main__":
tvm.testing.main()