kimm240 commented on code in PR #18636:
URL: https://github.com/apache/tvm/pull/18636#discussion_r2703900896
##########
src/tir/schedule/primitive/compute_inline.cc:
##########
@@ -1308,53 +1179,163 @@ Block
ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti
var_map[epilogue_data_vars[i]] = reduction_data_vars[i];
}
- // 2. Change init to epilogue value based on epilogue type
- BufferStore new_init_store;
- if (epilogue_type_ == EpilogueType::BiasReLU) {
- // For ReLU, init should be max(C[vi, vj], 0) to match per-iteration ReLU
semantics
- PrimExpr init_value = Substitute(epilogue_addend_, var_map);
- PrimExpr zero = tir::make_zero(init_value.dtype());
- new_init_store = BufferStore(epilogue_output_buffer_, Max(init_value,
zero),
- Substitute(epilogue_output_indices_,
var_map));
- } else if (epilogue_type_ == EpilogueType::Clipping) {
- // For Clipping, init should be min(max(init_value, lower), upper)
- // Since init is typically 0, this becomes min(max(0, lower), upper)
- PrimExpr init_value = tir::make_zero(epilogue_output_buffer_->dtype);
- PrimExpr clipped_init = Min(Max(init_value, Substitute(clipping_lower_,
var_map)),
- Substitute(clipping_upper_, var_map));
- new_init_store = BufferStore(epilogue_output_buffer_, clipped_init,
- Substitute(epilogue_output_indices_,
var_map));
- } else {
- // Bias: D[vi, vj] = C[vi, vj]
- new_init_store = BufferStore(epilogue_output_buffer_,
Substitute(epilogue_addend_, var_map),
- Substitute(epilogue_output_indices_,
var_map));
- }
+ // 2. Generalized init transformation: substitute reduction buffer load with
identity element (0)
+ // Create a substituter to replace reduction_buffer_load_ with identity
element
+ class InitSubstituter : public ExprMutator {
+ public:
+ InitSubstituter(const Buffer& target_buffer, PrimExpr identity_elem)
+ : target_buffer_(target_buffer), identity_elem_(identity_elem) {}
+
+ PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+ BufferLoad load = Downcast<BufferLoad>(ExprMutator::VisitExpr_(op));
+ if (load->buffer.same_as(target_buffer_)) {
+ return identity_elem_;
+ }
+ return load;
+ }
+
+ private:
+ Buffer target_buffer_;
+ PrimExpr identity_elem_;
+ };
+
+ // Identity element for reduction (assumed to be 0 for addition-based
reductions)
+ PrimExpr identity_elem = tir::make_zero(epilogue_output_buffer_->dtype);
Review Comment:
For `BiasAdd`, the transformation $\text{epilogue}(\sum x) = \sum x + \text{Bias}$ holds
regardless of whether the bias is handled at the initialization stage or at the
end.
However, for non-linear operations like `ReLU` and `Clipping`, this identity
does not hold if we simply substitute the identity element in the `Init` block,
since $f(\sum x) \neq f(0) + \sum x$ in general.
I initially aimed to generalize the logic by processing these operations
through a unified substitution mechanism (essentially treating them in a
'per-iteration' manner) to maintain the existing framework's structure — i.e.,
what was merged in #18515.
Should we strictly move these non-linear transformations to the final Store
stage (applied once to the final sum)?
Or, since the previous 'per-iteration' behavior was already merged and in use,
should we keep it as an option or as a specific 'fused-update' mode? I'm eager
to hear your thoughts.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]