wrongtest-intellif commented on code in PR #18636:
URL: https://github.com/apache/tvm/pull/18636#discussion_r2710810193
##########
src/tir/schedule/primitive/compute_inline.cc:
##########
@@ -1308,53 +1179,163 @@ Block
ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti
var_map[epilogue_data_vars[i]] = reduction_data_vars[i];
}
- // 2. Change init to epilogue value based on epilogue type
- BufferStore new_init_store;
- if (epilogue_type_ == EpilogueType::BiasReLU) {
- // For ReLU, init should be max(C[vi, vj], 0) to match per-iteration ReLU
semantics
- PrimExpr init_value = Substitute(epilogue_addend_, var_map);
- PrimExpr zero = tir::make_zero(init_value.dtype());
- new_init_store = BufferStore(epilogue_output_buffer_, Max(init_value,
zero),
- Substitute(epilogue_output_indices_,
var_map));
- } else if (epilogue_type_ == EpilogueType::Clipping) {
- // For Clipping, init should be min(max(init_value, lower), upper)
- // Since init is typically 0, this becomes min(max(0, lower), upper)
- PrimExpr init_value = tir::make_zero(epilogue_output_buffer_->dtype);
- PrimExpr clipped_init = Min(Max(init_value, Substitute(clipping_lower_,
var_map)),
- Substitute(clipping_upper_, var_map));
- new_init_store = BufferStore(epilogue_output_buffer_, clipped_init,
- Substitute(epilogue_output_indices_,
var_map));
- } else {
- // Bias: D[vi, vj] = C[vi, vj]
- new_init_store = BufferStore(epilogue_output_buffer_,
Substitute(epilogue_addend_, var_map),
- Substitute(epilogue_output_indices_,
var_map));
- }
+ // 2. Generalized init transformation: substitute reduction buffer load with
identity element (0)
+ // Create a substituter to replace reduction_buffer_load_ with identity
element
+ class InitSubstituter : public ExprMutator {
+ public:
+ InitSubstituter(const Buffer& target_buffer, PrimExpr identity_elem)
+ : target_buffer_(target_buffer), identity_elem_(identity_elem) {}
+
+ PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+ BufferLoad load = Downcast<BufferLoad>(ExprMutator::VisitExpr_(op));
+ if (load->buffer.same_as(target_buffer_)) {
+ return identity_elem_;
+ }
+ return load;
+ }
+
+ private:
+ Buffer target_buffer_;
+ PrimExpr identity_elem_;
+ };
+
+ // Identity element for reduction (assumed to be 0 for addition-based
reductions)
+ PrimExpr identity_elem = tir::make_zero(epilogue_output_buffer_->dtype);
Review Comment:
Correctness matters most. The concrete strategy is up to you.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]