wrongtest-intellif commented on code in PR #18636:
URL: https://github.com/apache/tvm/pull/18636#discussion_r2703562714
##########
src/tir/schedule/primitive/compute_inline.cc:
##########
@@ -1308,53 +1179,163 @@ Block
ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti
var_map[epilogue_data_vars[i]] = reduction_data_vars[i];
}
- // 2. Change init to epilogue value based on epilogue type
- BufferStore new_init_store;
- if (epilogue_type_ == EpilogueType::BiasReLU) {
- // For ReLU, init should be max(C[vi, vj], 0) to match per-iteration ReLU
semantics
- PrimExpr init_value = Substitute(epilogue_addend_, var_map);
- PrimExpr zero = tir::make_zero(init_value.dtype());
- new_init_store = BufferStore(epilogue_output_buffer_, Max(init_value,
zero),
- Substitute(epilogue_output_indices_,
var_map));
- } else if (epilogue_type_ == EpilogueType::Clipping) {
- // For Clipping, init should be min(max(init_value, lower), upper)
- // Since init is typically 0, this becomes min(max(0, lower), upper)
- PrimExpr init_value = tir::make_zero(epilogue_output_buffer_->dtype);
- PrimExpr clipped_init = Min(Max(init_value, Substitute(clipping_lower_,
var_map)),
- Substitute(clipping_upper_, var_map));
- new_init_store = BufferStore(epilogue_output_buffer_, clipped_init,
- Substitute(epilogue_output_indices_,
var_map));
- } else {
- // Bias: D[vi, vj] = C[vi, vj]
- new_init_store = BufferStore(epilogue_output_buffer_,
Substitute(epilogue_addend_, var_map),
- Substitute(epilogue_output_indices_,
var_map));
- }
+ // 2. Generalized init transformation: substitute reduction buffer load with
identity element (0)
+ // Create a substituter to replace reduction_buffer_load_ with identity
element
+ class InitSubstituter : public ExprMutator {
+ public:
+ InitSubstituter(const Buffer& target_buffer, PrimExpr identity_elem)
+ : target_buffer_(target_buffer), identity_elem_(identity_elem) {}
+
+ PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+ BufferLoad load = Downcast<BufferLoad>(ExprMutator::VisitExpr_(op));
+ if (load->buffer.same_as(target_buffer_)) {
+ return identity_elem_;
+ }
+ return load;
+ }
+
+ private:
+ Buffer target_buffer_;
+ PrimExpr identity_elem_;
+ };
+
+ // Identity element for reduction (assumed to be 0 for addition-based
reductions)
+ PrimExpr identity_elem = tir::make_zero(epilogue_output_buffer_->dtype);
+
+ // Substitute reduction buffer load with identity element
+ InitSubstituter init_subst(inlined_buffer_, identity_elem);
+ PrimExpr init_epilogue = init_subst(epilogue_expression_);
+
+ // Apply index mapping
+ init_epilogue = Substitute(init_epilogue, var_map);
+
+ // Simplify the expression (e.g., 0 + C[vi, vj] -> C[vi, vj])
+ arith::Analyzer analyzer;
+ init_epilogue = analyzer.Simplify(init_epilogue);
+
+ BufferStore new_init_store = BufferStore(epilogue_output_buffer_,
init_epilogue,
+
Substitute(epilogue_output_indices_, var_map));
new_block->init = new_init_store;
- // 3. Replace output buffer from temp to D in body
- class BufferReplacer : public StmtExprMutator {
+ // 3. Generalized update transformation: apply epilogue expression with
reduction buffer replaced
+ // If reduction buffer load's parent is Add and other operand is not a
reduction buffer,
+ // remove that operand (bias addend) from update expression
+ class UpdateSubstituter : public StmtExprMutator {
public:
- BufferReplacer(Buffer old_buf, Buffer new_buf, EpilogueType epilogue_type,
DataType dtype,
- PrimExpr clipping_lower = PrimExpr(), PrimExpr
clipping_upper = PrimExpr())
+ UpdateSubstituter(const Buffer& old_buf, const Buffer& new_buf, const
Buffer& reduction_buf,
+ const PrimExpr& epilogue_expr, const
std::unordered_map<Var, Var>& var_map)
: old_buffer_(old_buf),
new_buffer_(new_buf),
- epilogue_type_(epilogue_type),
- dtype_(dtype),
- clipping_lower_(clipping_lower),
- clipping_upper_(clipping_upper) {}
+ reduction_buffer_(reduction_buf),
+ epilogue_expression_(epilogue_expr),
+ var_map_(var_map) {}
Stmt VisitStmt_(const BufferStoreNode* op) final {
BufferStore store =
Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
if (store->buffer.same_as(old_buffer_)) {
- PrimExpr new_value = store->value;
- // For ReLU, apply max per iteration to match per-iteration ReLU
semantics
- if (epilogue_type_ == EpilogueType::BiasReLU) {
- PrimExpr zero = tir::make_zero(dtype_);
- new_value = Max(new_value, zero);
- } else if (epilogue_type_ == EpilogueType::Clipping) {
- // For Clipping, apply min(max(value, lower), upper) per iteration
- new_value = Min(Max(new_value, clipping_lower_), clipping_upper_);
- }
+ // Replace old_buffer_ in store->value with new_buffer_ to get the
reduction update
+ // expression This ensures store->value references new_buffer_ instead
of old_buffer_
+ class ReductionUpdateReplacer : public ExprMutator {
+ public:
+ ReductionUpdateReplacer(const Buffer& old_buf, const Buffer& new_buf)
+ : old_buffer_(old_buf), new_buffer_(new_buf) {}
+
+ PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+ BufferLoad load =
Downcast<BufferLoad>(ExprMutator::VisitExpr_(op));
+ if (load->buffer.same_as(old_buffer_)) {
+ return BufferLoad(new_buffer_, load->indices);
+ }
+ return load;
+ }
+
+ private:
+ Buffer old_buffer_;
+ Buffer new_buffer_;
+ };
+
+ ReductionUpdateReplacer reduction_replacer(old_buffer_, new_buffer_);
+ PrimExpr reduction_update = reduction_replacer(store->value);
+
+ // Generalized approach: apply epilogue expression with reduction
buffer load replaced
+ // If reduction buffer load's direct parent is Add and the other
operand is not a reduction
+ // buffer, remove that operand (bias addend) from the update expression
+ class GeneralizedEpilogueApplier : public ExprMutator {
+ public:
+ GeneralizedEpilogueApplier(const Buffer& target_buf, const Buffer&
reduction_buf,
+ const PrimExpr& replacement)
+ : target_buffer_(target_buf),
+ reduction_buffer_(reduction_buf),
+ replacement_(replacement),
+ found_target_load_(false) {}
+
+ PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+ BufferLoad load =
Downcast<BufferLoad>(ExprMutator::VisitExpr_(op));
+ if (load->buffer.same_as(target_buffer_)) {
+ found_target_load_ = true;
+ // Check if parent is Add (will be checked in VisitExpr_(const
AddNode*))
+ return replacement_;
+ }
+ return load;
+ }
+
+ PrimExpr VisitExpr_(const AddNode* op) final {
Review Comment:
Please try to ensure that no incorrect patterns get through, such as
`epilogue = relu(reduce_out * a + C[i])`, `epilogue = (reduce_out + C[i]) * (reduce_out + C[i])`, etc.
##########
src/tir/schedule/primitive/compute_inline.cc:
##########
@@ -1308,53 +1179,163 @@ Block
ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti
var_map[epilogue_data_vars[i]] = reduction_data_vars[i];
}
- // 2. Change init to epilogue value based on epilogue type
- BufferStore new_init_store;
- if (epilogue_type_ == EpilogueType::BiasReLU) {
- // For ReLU, init should be max(C[vi, vj], 0) to match per-iteration ReLU
semantics
- PrimExpr init_value = Substitute(epilogue_addend_, var_map);
- PrimExpr zero = tir::make_zero(init_value.dtype());
- new_init_store = BufferStore(epilogue_output_buffer_, Max(init_value,
zero),
- Substitute(epilogue_output_indices_,
var_map));
- } else if (epilogue_type_ == EpilogueType::Clipping) {
- // For Clipping, init should be min(max(init_value, lower), upper)
- // Since init is typically 0, this becomes min(max(0, lower), upper)
- PrimExpr init_value = tir::make_zero(epilogue_output_buffer_->dtype);
- PrimExpr clipped_init = Min(Max(init_value, Substitute(clipping_lower_,
var_map)),
- Substitute(clipping_upper_, var_map));
- new_init_store = BufferStore(epilogue_output_buffer_, clipped_init,
- Substitute(epilogue_output_indices_,
var_map));
- } else {
- // Bias: D[vi, vj] = C[vi, vj]
- new_init_store = BufferStore(epilogue_output_buffer_,
Substitute(epilogue_addend_, var_map),
- Substitute(epilogue_output_indices_,
var_map));
- }
+ // 2. Generalized init transformation: substitute reduction buffer load with
identity element (0)
+ // Create a substituter to replace reduction_buffer_load_ with identity
element
+ class InitSubstituter : public ExprMutator {
+ public:
+ InitSubstituter(const Buffer& target_buffer, PrimExpr identity_elem)
+ : target_buffer_(target_buffer), identity_elem_(identity_elem) {}
+
+ PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+ BufferLoad load = Downcast<BufferLoad>(ExprMutator::VisitExpr_(op));
+ if (load->buffer.same_as(target_buffer_)) {
+ return identity_elem_;
+ }
+ return load;
+ }
+
+ private:
+ Buffer target_buffer_;
+ PrimExpr identity_elem_;
+ };
+
+ // Identity element for reduction (assumed to be 0 for addition-based
reductions)
+ PrimExpr identity_elem = tir::make_zero(epilogue_output_buffer_->dtype);
Review Comment:
Generally, we require `epilogue(zero ⊕ x0 ⊕ ... ⊕ xn) == g(xn, g(xn-1,
... g(x0, epilogue_init(zero)) ...))`, with `g` and `epilogue_init` deduced
from the epilogue expression, right?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]