This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 8521c2f7f3 [TIR][Schedule] FuseReductionEpilogue: Add Clipping pattern support (#18515)
8521c2f7f3 is described below
commit 8521c2f7f3a9db5393ddf37761264f1657d15206
Author: kimm240 <[email protected]>
AuthorDate: Thu Dec 25 21:27:23 2025 +0900
[TIR][Schedule] FuseReductionEpilogue: Add Clipping pattern support (#18515)
Before this change, the FuseReductionEpilogue primitive on main supported
only the Bias (addition) epilogue pattern. However, ReLU (max(x, 0)) and
clipping (min(max(x, lower), upper)) operations are commonly used in deep
learning models and would benefit from the same fusion optimization; this
change adds BiasReLU (addition + ReLU) support alongside Clipping.
This commit extends FuseReductionEpilogue to support Clipping patterns
by:
1. Introducing an EpilogueType enum (Bias, BiasReLU, Clipping) so that
clipping patterns can be distinguished from the other epilogue types.
2. Adding clipping_lower_ and clipping_upper_ members to
ReductionEpilogueFuser to store clipping bounds extracted from the
epilogue pattern.
3. Extending AnalyzeEpiloguePattern to detect clipping patterns:
- min(max(temp, lower), upper)
- max(min(temp, upper), lower)
- All commutative variants of min/max at each level
4. Matching the BiasReLU pattern max(temp + C, 0) in both operand orders,
max(x, 0) and max(0, x), with either an integer or a float zero constant,
as well as the bias-free form max(temp, 0).
5. Modifying CreateFusedReductionBlock to apply clipping to the init
value. The reduction is assumed to start from zero, so the init becomes
init = min(max(0, lower), upper).
6. Updating BufferReplacer to apply clipping per-iteration:
value = min(max(value, lower), upper)
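7. Adding validation in BodyPatternAllowFusion to ensure temp appears
exactly once in the epilogue expression. The check applies to every
supported pattern (Bias, BiasReLU, Clipping), not only to clipping.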
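8. Creating test coverage with 8 clipping test cases:
- Basic fusion test
- Numerical correctness verification
- Multiple epilogue blocks test
- 5 parametrized commutative-variant tests (covering 5 of the 8 possible
  operand orderings)
A new ReLU test file adds 3 analogous BiasReLU tests (fusion structure,
numerical correctness, multiple epilogue blocks).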
This implementation follows the same per-iteration semantics as BiasReLU:
clipping is applied at each reduction step rather than once after the
reduction, so the fused block does not in general compute the same result
as the unfused program. This semantic change is documented in the
fuse_reduction_epilogue docstring with a warning about potential numerical
differences.
The test suite verifies that the tested commutative forms of the clipping
pattern (5 of the 8 operand orderings) are recognized, and that the fused
implementation matches the per-iteration reference implementation within
tolerance (rtol=1e-5, atol=1e-6).
---------
Co-authored-by: hyun gyu kim <[email protected]>
---
.gitignore | 3 +
python/tvm/tir/schedule/schedule.py | 31 ++-
src/tir/schedule/primitive/compute_inline.cc | 224 +++++++++++++++--
...ir_schedule_fuse_reduction_epilogue_clipping.py | 271 +++++++++++++++++++++
...st_tir_schedule_fuse_reduction_epilogue_relu.py | 229 +++++++++++++++++
5 files changed, 741 insertions(+), 17 deletions(-)
diff --git a/.gitignore b/.gitignore
index 5bcbd5e373..6fa10a5e76 100644
--- a/.gitignore
+++ b/.gitignore
@@ -274,3 +274,6 @@ tvm-site/
# GDB history file
.gdb_history
+
+# Less command history file
+.lesshst
diff --git a/python/tvm/tir/schedule/schedule.py
b/python/tvm/tir/schedule/schedule.py
index 0d41ffe943..b1e1a3f5d5 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -2356,14 +2356,41 @@ class Schedule(Object):
It requires:
1) The reduction block is a complete reduction block
2) The epilogue block only reads from the reduction block's output
- 3) The epilogue performs a simple addition: output = reduction_result
+ bias
+ 3) The epilogue matches one of the supported patterns:
+ - Bias: ``output = reduction_result + bias``
+ - BiasReLU: ``output = max(reduction_result + bias, 0)``
+ - Clipping: ``output = min(max(reduction_result, lower), upper)``
+ or their commutative variants
+
+ .. warning::
+
+ **Semantic Change for Non-Linear Epilogues (BiasReLU, Clipping):**
+
+ For non-linear epilogues (BiasReLU and Clipping), fusion changes
the
+ computation semantics from post-reduction application to
per-iteration
+ application. This can lead to different numerical results.
+
+ **Example with Clipping to [-5, 5] and inputs [6, -2]:**
+
+ - **Post-reduction clipping** (original): ``clip(sum([6, -2])) =
clip(4) = 4``
+ - **Per-iteration clipping** (fused): ``acc=0 → clip(0+6)=5 →
clip(5+(-2))=3``
+
+ The fused version applies clipping at each reduction iteration,
which
+ may be an intended optimization for some models but can cause
unexpected
+ correctness issues if users are not aware of this behavior.
+
+ For linear epilogues (Bias), fusion preserves exact numerical
equivalence.
Parameters
----------
reduction_block : Union[BlockRV, str]
The reduction block (e.g., matmul)
epilogue_block : Union[BlockRV, str]
- The epilogue block to be fused (e.g., bias add)
+ The epilogue block to be fused (e.g., bias add, ReLU, clipping)
+
+ Examples
+ --------
+ See :py:func:`test_tir_schedule_fuse_reduction_epilogue` for examples.
"""
reduction_block = self._normalize_block_arg(reduction_block)
epilogue_block = self._normalize_block_arg(epilogue_block)
diff --git a/src/tir/schedule/primitive/compute_inline.cc
b/src/tir/schedule/primitive/compute_inline.cc
index cc3785d5c1..0ab6d7e2b6 100644
--- a/src/tir/schedule/primitive/compute_inline.cc
+++ b/src/tir/schedule/primitive/compute_inline.cc
@@ -988,6 +988,13 @@ void ReverseComputeInline(ScheduleState self, const
StmtSRef& consumer_block_sre
* \brief Helper to fuse epilogue block into reduction block
* Analyzes epilogue pattern and transforms reduction init/update
*/
+// Epilogue type enumeration
+enum class EpilogueType {
+ Bias, // temp + C
+ BiasReLU, // max(temp + C, 0)
+ Clipping, // min(max(temp, lower), upper)
+};
+
class ReductionEpilogueFuser : public BaseInliner {
public:
explicit ReductionEpilogueFuser(const Buffer& reduction_buffer, const
BlockNode* reduction_block,
@@ -995,7 +1002,19 @@ class ReductionEpilogueFuser : public BaseInliner {
const StmtSRef& scope_root_sref)
: BaseInliner(reduction_buffer, epilogue_block_realize->block,
scope_root_sref),
reduction_block_(reduction_block),
- epilogue_block_(epilogue_block_realize->block.get()) {}
+ epilogue_block_(epilogue_block_realize->block.get()),
+ epilogue_type_(EpilogueType::Bias) {
+ // Disable opaque access check for epilogue fusion
+ // Epilogue blocks can read multiple buffers (temp + bias), which is
allowed
+ has_opaque_access = false;
+ }
+
+ // Override CheckOpaqueAccess to allow multiple buffer reads
+ void CheckOpaqueAccess(const VarNode* buffer_var) {
+ // For epilogue fusion, we allow multiple buffer reads (temp + bias)
+ // So we don't check for opaque access
+ // BaseInliner::CheckOpaqueAccess(buffer_var); // Don't call base class
+ }
bool BodyPatternAllowFusion(const BlockRealize& epilogue_block_realize);
@@ -1012,18 +1031,21 @@ class ReductionEpilogueFuser : public BaseInliner {
const
BufferStoreNode* from) {
struct Extractor : public ExprVisitor {
void VisitExpr_(const BufferLoadNode* load) final {
- if (load->buffer.get() == buffer) {
+ if (load->buffer.same_as(buffer)) {
result.push_back(load);
}
+ // Continue visiting child nodes (indices)
ExprVisitor::VisitExpr_(load);
}
- const BufferNode* buffer;
+ Buffer buffer;
std::vector<const BufferLoadNode*> result;
} extractor;
- extractor.buffer = buffer.get();
+ extractor.buffer = buffer;
+ // Visit indices first (though they typically don't contain BufferLoad)
for (const PrimExpr& expr : from->indices) {
extractor(expr);
}
+ // Visit the value expression (e.g., max(temp + C, 0) for ReLU)
extractor(from->value);
return std::move(extractor.result);
}
@@ -1036,6 +1058,9 @@ class ReductionEpilogueFuser : public BaseInliner {
BufferRegion epilogue_output_region_{nullptr}; // Write region of D
Buffer epilogue_addend_buffer_{nullptr}; // Addend buffer C
BufferRegion epilogue_addend_region_{nullptr}; // Read region of C
+ EpilogueType epilogue_type_; // Type of epilogue
operation
+ PrimExpr clipping_lower_{nullptr}; // Lower bound for
clipping
+ PrimExpr clipping_upper_{nullptr}; // Upper bound for
clipping
};
bool ReductionEpilogueFuser::BodyPatternAllowFusion(const BlockRealize&
epilogue_block_realize) {
@@ -1058,26 +1083,36 @@ bool
ReductionEpilogueFuser::BodyPatternAllowFusion(const BlockRealize& epilogue
return false;
}
- // 4. Analyze epilogue pattern: D[i,j] = temp[i,j] + C[i,j]
+ // 4. Analyze epilogue pattern: D[i,j] = temp[i,j] + C[i,j] or
+ // D[i,j] = min(max(temp[i,j], lower), upper)
if (!AnalyzeEpiloguePattern(inlined_store_->value)) {
- // Failure: epilogue is not a simple addition pattern
+ // Failure: epilogue is not a supported pattern (Bias, BiasReLU, or
Clipping)
+ return false;
+ }
+
+ // 5. Verify temp appears exactly once in the epilogue pattern
+ // This ensures correctness for all supported patterns (Bias, BiasReLU,
Clipping)
+ // The reduction result buffer must be used exactly once in the epilogue
expression
+ if (loads.size() != 1) {
+ // Failure: The reduction result (temp) must be used exactly once in the
+ // epilogue expression for fusion.
return false;
}
- // 5. Check if producer is a reduction block
+ // 6. Check if producer is a reduction block
if (!IsReductionBlock(reduction_block_)) {
// Failure: producer is not a reduction block
return false;
}
- // 6. Extract epilogue information (output buffer, indices, regions, etc.)
+ // 7. Extract epilogue information (output buffer, indices, regions, etc.)
ExtractEpilogueInfo();
return true;
}
bool ReductionEpilogueFuser::AnalyzeEpiloguePattern(const PrimExpr& value) {
- // Pattern: temp[i,j] + C[i,j] or C[i,j] + temp[i,j]
+ // Pattern 1: temp[i,j] + C[i,j] or C[i,j] + temp[i,j] (Bias)
if (const auto* add = value.as<AddNode>()) {
const auto* load_a = add->a.as<BufferLoadNode>();
const auto* load_b = add->b.as<BufferLoadNode>();
@@ -1088,10 +1123,125 @@ bool
ReductionEpilogueFuser::AnalyzeEpiloguePattern(const PrimExpr& value) {
// Ensure exactly one operand is from the reduction buffer
if (a_is_target != b_is_target) {
epilogue_addend_ = a_is_target ? add->b : add->a;
+ epilogue_type_ = EpilogueType::Bias;
return true;
}
}
+ // Pattern 2: min(max(temp[i,j], lower), upper) or max(min(temp[i,j],
upper), lower) (Clipping)
+ // Handle all commutative variants of min/max at each level.
+
+ // Helper to check if an expression is a load from the reduction buffer, and
+ // return the other operand as `other` if so.
+ auto match_buffer_in_commutative_op = [this](const PrimExpr& a, const
PrimExpr& b,
+ PrimExpr* other) -> bool {
+ if (const auto* load_a = a.as<BufferLoadNode>()) {
+ if (load_a->buffer.same_as(inlined_buffer_)) {
+ *other = b;
+ return true;
+ }
+ }
+ if (const auto* load_b = b.as<BufferLoadNode>()) {
+ if (load_b->buffer.same_as(inlined_buffer_)) {
+ *other = a;
+ return true;
+ }
+ }
+ return false;
+ };
+
+ // Check for min(max(temp, lower), upper) and commutative variants
+ if (const auto* min_node = value.as<MinNode>()) {
+ const MaxNode* max_node = nullptr;
+ PrimExpr upper;
+ // Try both (a, b) as possible positions of the inner max
+ if ((max_node = min_node->a.as<MaxNode>())) {
+ upper = min_node->b;
+ } else if ((max_node = min_node->b.as<MaxNode>())) {
+ upper = min_node->a;
+ }
+ if (max_node != nullptr) {
+ PrimExpr lower;
+ if (match_buffer_in_commutative_op(max_node->a, max_node->b, &lower)) {
+ clipping_lower_ = lower;
+ clipping_upper_ = upper;
+ epilogue_type_ = EpilogueType::Clipping;
+ return true;
+ }
+ }
+ }
+
+ // Check for max(min(temp[i,j], upper), lower) and commutative variants
+ if (const auto* max_node = value.as<MaxNode>()) {
+ const MinNode* min_node = nullptr;
+ PrimExpr lower;
+ // Try both (a, b) as possible positions of the inner min
+ if ((min_node = max_node->a.as<MinNode>())) {
+ lower = max_node->b;
+ } else if ((min_node = max_node->b.as<MinNode>())) {
+ lower = max_node->a;
+ }
+ if (min_node != nullptr) {
+ PrimExpr upper;
+ if (match_buffer_in_commutative_op(min_node->a, min_node->b, &upper)) {
+ clipping_lower_ = lower;
+ clipping_upper_ = upper;
+ epilogue_type_ = EpilogueType::Clipping;
+ return true;
+ }
+ }
+ }
+
+ // Pattern 3: max(temp[i,j] + C[i,j], 0) or max(C[i,j] + temp[i,j], 0)
(BiasReLU)
+ // Also handle max(0, temp[i,j] + C[i,j]) or max(0, C[i,j] + temp[i,j])
+ if (const auto* max_node = value.as<MaxNode>()) {
+ // Check if either operand is zero (ReLU: max(x, 0) or max(0, x))
+ // Support both integer and float zero constants.
+ const PrimExpr* add_candidate = nullptr;
+ bool is_zero_const = false;
+ auto is_zero_expr = [](const PrimExpr& expr) -> bool {
+ if (tir::is_zero(expr)) {
+ return true;
+ }
+ if (const auto* float_imm = expr.as<FloatImmNode>()) {
+ return float_imm->value == 0.0;
+ }
+ return false;
+ };
+
+ if (is_zero_expr(max_node->a)) {
+ is_zero_const = true;
+ add_candidate = &max_node->b;
+ } else if (is_zero_expr(max_node->b)) {
+ is_zero_const = true;
+ add_candidate = &max_node->a;
+ }
+
+ if (is_zero_const && add_candidate != nullptr) {
+ if (const auto* add = add_candidate->as<AddNode>()) {
+ const auto* load_a = add->a.as<BufferLoadNode>();
+ const auto* load_b = add->b.as<BufferLoadNode>();
+
+ bool a_is_target = load_a && load_a->buffer.same_as(inlined_buffer_);
+ bool b_is_target = load_b && load_b->buffer.same_as(inlined_buffer_);
+
+ // Ensure exactly one operand is from the reduction buffer
+ if (a_is_target != b_is_target) {
+ epilogue_addend_ = a_is_target ? add->b : add->a;
+ epilogue_type_ = EpilogueType::BiasReLU;
+ return true;
+ }
+ } else if (const auto* load = add_candidate->as<BufferLoadNode>()) {
+ // Handle bias-free ReLU: max(temp, 0) or max(0, temp)
+ if (load->buffer.same_as(inlined_buffer_)) {
+ epilogue_addend_ = tir::make_zero(load->dtype);
+ epilogue_type_ = EpilogueType::BiasReLU;
+ return true;
+ }
+ }
+ }
+ }
+
return false;
}
@@ -1158,20 +1308,54 @@ Block
ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti
var_map[epilogue_data_vars[i]] = reduction_data_vars[i];
}
- // 2. Change init to epilogue value: D[vi, vj] = C[vi, vj]
- BufferStore new_init_store(epilogue_output_buffer_,
Substitute(epilogue_addend_, var_map),
- Substitute(epilogue_output_indices_, var_map));
+ // 2. Change init to epilogue value based on epilogue type
+ BufferStore new_init_store;
+ if (epilogue_type_ == EpilogueType::BiasReLU) {
+ // For ReLU, init should be max(C[vi, vj], 0) to match per-iteration ReLU
semantics
+ PrimExpr init_value = Substitute(epilogue_addend_, var_map);
+ PrimExpr zero = tir::make_zero(init_value.dtype());
+ new_init_store = BufferStore(epilogue_output_buffer_, Max(init_value,
zero),
+ Substitute(epilogue_output_indices_,
var_map));
+ } else if (epilogue_type_ == EpilogueType::Clipping) {
+ // For Clipping, init should be min(max(init_value, lower), upper)
+ // Since init is typically 0, this becomes min(max(0, lower), upper)
+ PrimExpr init_value = tir::make_zero(epilogue_output_buffer_->dtype);
+ PrimExpr clipped_init = Min(Max(init_value, Substitute(clipping_lower_,
var_map)),
+ Substitute(clipping_upper_, var_map));
+ new_init_store = BufferStore(epilogue_output_buffer_, clipped_init,
+ Substitute(epilogue_output_indices_,
var_map));
+ } else {
+ // Bias: D[vi, vj] = C[vi, vj]
+ new_init_store = BufferStore(epilogue_output_buffer_,
Substitute(epilogue_addend_, var_map),
+ Substitute(epilogue_output_indices_,
var_map));
+ }
new_block->init = new_init_store;
// 3. Replace output buffer from temp to D in body
class BufferReplacer : public StmtExprMutator {
public:
- BufferReplacer(Buffer old_buf, Buffer new_buf) : old_buffer_(old_buf),
new_buffer_(new_buf) {}
+ BufferReplacer(Buffer old_buf, Buffer new_buf, EpilogueType epilogue_type,
DataType dtype,
+ PrimExpr clipping_lower = PrimExpr(), PrimExpr
clipping_upper = PrimExpr())
+ : old_buffer_(old_buf),
+ new_buffer_(new_buf),
+ epilogue_type_(epilogue_type),
+ dtype_(dtype),
+ clipping_lower_(clipping_lower),
+ clipping_upper_(clipping_upper) {}
Stmt VisitStmt_(const BufferStoreNode* op) final {
BufferStore store =
Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
if (store->buffer.same_as(old_buffer_)) {
- return BufferStore(new_buffer_, store->value, store->indices);
+ PrimExpr new_value = store->value;
+ // For ReLU, apply max per iteration to match per-iteration ReLU
semantics
+ if (epilogue_type_ == EpilogueType::BiasReLU) {
+ PrimExpr zero = tir::make_zero(dtype_);
+ new_value = Max(new_value, zero);
+ } else if (epilogue_type_ == EpilogueType::Clipping) {
+ // For Clipping, apply min(max(value, lower), upper) per iteration
+ new_value = Min(Max(new_value, clipping_lower_), clipping_upper_);
+ }
+ return BufferStore(new_buffer_, new_value, store->indices);
}
return store;
}
@@ -1187,9 +1371,19 @@ Block
ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti
private:
Buffer old_buffer_;
Buffer new_buffer_;
+ EpilogueType epilogue_type_;
+ DataType dtype_;
+ PrimExpr clipping_lower_;
+ PrimExpr clipping_upper_;
};
- BufferReplacer replacer(inlined_buffer_, epilogue_output_buffer_);
+ DataType dtype = epilogue_output_buffer_->dtype;
+ PrimExpr clipping_lower_subst =
+ epilogue_type_ == EpilogueType::Clipping ? Substitute(clipping_lower_,
var_map) : PrimExpr();
+ PrimExpr clipping_upper_subst =
+ epilogue_type_ == EpilogueType::Clipping ? Substitute(clipping_upper_,
var_map) : PrimExpr();
+ BufferReplacer replacer(inlined_buffer_, epilogue_output_buffer_,
epilogue_type_, dtype,
+ clipping_lower_subst, clipping_upper_subst);
new_block->body = replacer(reduction_block->body);
// 4. Update write regions
diff --git
a/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue_clipping.py
b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue_clipping.py
new file mode 100644
index 0000000000..6b3338b9a1
--- /dev/null
+++
b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue_clipping.py
@@ -0,0 +1,271 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import pytest
+import tvm
+import tvm.testing
+from tvm import tir
+from tvm.script import tir as T
+from tvm.tir.schedule.testing import (
+ verify_trace_roundtrip,
+ assert_structural_equal_ignore_global_symbol,
+)
+import numpy as np
+
+# pylint: disable=no-member,invalid-name,unused-variable
+
+
[email protected]_func
+def matmul_clipping_before(
+ A: T.Buffer((16, 16), "float32"),
+ B: T.Buffer((16, 16), "float32"),
+ D: T.Buffer((16, 16), "float32"),
+ lower: T.float32,
+ upper: T.float32,
+) -> None:
+ """Original function with separate reduction and clipping epilogue
blocks."""
+ temp = T.alloc_buffer((16, 16), dtype="float32")
+ for i, j, k in T.grid(16, 16, 16):
+ with T.block("matmul"):
+ vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+ with T.init():
+ temp[vi, vj] = T.float32(0)
+ temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk]
+
+ for i, j in T.grid(16, 16):
+ with T.block("clipping"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ D[vi, vj] = T.min(T.max(temp[vi, vj], lower), upper)
+
+
[email protected]_func
+def matmul_clipping_expected(
+ A: T.Buffer((16, 16), "float32"),
+ B: T.Buffer((16, 16), "float32"),
+ D: T.Buffer((16, 16), "float32"),
+ lower: T.float32,
+ upper: T.float32,
+) -> None:
+ """Expected function after fusion (Clipping)."""
+ temp = T.alloc_buffer((16, 16), dtype="float32")
+ for i, j, k in T.grid(16, 16, 16):
+ with T.block("matmul"):
+ vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+ T.reads(A[vi, vk], B[vj, vk])
+ T.writes(D[vi, vj])
+ with T.init():
+ D[vi, vj] = T.min(T.max(T.float32(0), lower), upper)
+ D[vi, vj] = T.min(T.max(D[vi, vj] + A[vi, vk] * B[vj, vk], lower),
upper)
+
+
+def test_matmul_clipping():
+ """Test fusion of matmul with clipping epilogue."""
+ sch = tir.Schedule(matmul_clipping_before, debug_mask="all")
+ sch.fuse_reduction_epilogue("matmul", "clipping")
+ assert_structural_equal_ignore_global_symbol(sch.mod["main"],
matmul_clipping_expected)
+ verify_trace_roundtrip(sch=sch, mod=matmul_clipping_before)
+
+
[email protected]_func
+def matmul_clipping_before_per_iteration(
+ A: T.Buffer((16, 16), "float32"),
+ B: T.Buffer((16, 16), "float32"),
+ D: T.Buffer((16, 16), "float32"),
+) -> None:
+ """Original function with per-iteration clipping (same semantics as
fused)."""
+ temp = T.alloc_buffer((16, 16), dtype="float32")
+ lower = T.float32(-5.0)
+ upper = T.float32(5.0)
+ for i, j in T.grid(16, 16):
+ with T.block("init"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ temp[vi, vj] = T.min(T.max(T.float32(0), lower), upper) # Clip
init
+
+ for i, j, k in T.grid(16, 16, 16):
+ with T.block("matmul"):
+ vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+ # Per-iteration clipping
+ temp[vi, vj] = T.min(T.max(temp[vi, vj] + A[vi, vk] * B[vj, vk],
lower), upper)
+
+ for i, j in T.grid(16, 16):
+ with T.block("copy"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ D[vi, vj] = temp[vi, vj]
+
+
+def test_matmul_clipping_correctness_unified():
+ """Test that original and fused produce identical results with
per-iteration clipping."""
+ A_np = np.random.randn(16, 16).astype("float32")
+ B_np = np.random.randn(16, 16).astype("float32")
+ lower = -5.0
+ upper = 5.0
+
+ # NumPy reference for per-iteration clipping
+ D_ref = np.clip(0.0, lower, upper) # init with clipping
+ for k in range(16):
+ D_ref = np.clip(D_ref + np.outer(A_np[:, k], B_np[:, k]), lower, upper)
+
+ # TVM execution (original with per-iteration clipping)
+ mod_original = tvm.compile(matmul_clipping_before_per_iteration,
target="llvm")
+ D_original_tvm = tvm.runtime.tensor(np.zeros((16, 16), dtype="float32"))
+ mod_original(
+ tvm.runtime.tensor(A_np),
+ tvm.runtime.tensor(B_np),
+ D_original_tvm,
+ )
+
+ # TVM execution (fused)
+ sch = tir.Schedule(matmul_clipping_before)
+ sch.fuse_reduction_epilogue("matmul", "clipping")
+ mod_fused = tvm.compile(sch.mod["main"], target="llvm")
+ D_fused_tvm = tvm.runtime.tensor(np.zeros((16, 16), dtype="float32"))
+ # Pass scalar values directly as Python floats
+ mod_fused(
+ tvm.runtime.tensor(A_np),
+ tvm.runtime.tensor(B_np),
+ D_fused_tvm,
+ lower,
+ upper,
+ )
+
+ D_original = D_original_tvm.numpy()
+ D_fused = D_fused_tvm.numpy()
+
+ # Now both should match exactly
+ np.testing.assert_allclose(D_original, D_ref, rtol=1e-5, atol=1e-6)
+ np.testing.assert_allclose(D_fused, D_ref, rtol=1e-5, atol=1e-6)
+ np.testing.assert_allclose(D_original, D_fused, rtol=1e-5, atol=1e-6)
+
+
[email protected]_func
+def matmul_clipping_multiple_epilogue_before(
+ A: T.Buffer((16, 16), "float32"),
+ B: T.Buffer((16, 16), "float32"),
+ D: T.Buffer((16, 16), "float32"),
+ E: T.Buffer((16, 16), "float32"),
+ lower: T.float32,
+ upper: T.float32,
+) -> None:
+ """Original function with separate reduction and multiple epilogue blocks
(one with clipping, one without)."""
+ temp = T.alloc_buffer((16, 16), dtype="float32")
+ for i, j, k in T.grid(16, 16, 16):
+ with T.block("matmul"):
+ vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+ with T.init():
+ temp[vi, vj] = T.float32(0)
+ temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk]
+
+ for i, j in T.grid(16, 16):
+ with T.block("clipping"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ D[vi, vj] = T.min(T.max(temp[vi, vj], lower), upper)
+
+ for i, j in T.grid(16, 16):
+ with T.block("copy"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ E[vi, vj] = temp[vi, vj]
+
+
[email protected]_func
+def matmul_clipping_multiple_epilogue_expected(
+ A: T.Buffer((16, 16), "float32"),
+ B: T.Buffer((16, 16), "float32"),
+ D: T.Buffer((16, 16), "float32"),
+ E: T.Buffer((16, 16), "float32"),
+ lower: T.float32,
+ upper: T.float32,
+) -> None:
+ """Expected function after fusion (Clipping) with multiple epilogue
blocks."""
+ temp = T.alloc_buffer((16, 16), dtype="float32")
+ for i, j, k in T.grid(16, 16, 16):
+ with T.block("matmul"):
+ vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+ T.reads(A[vi, vk], B[vj, vk])
+ T.writes(D[vi, vj])
+ with T.init():
+ D[vi, vj] = T.min(T.max(T.float32(0), lower), upper)
+ D[vi, vj] = T.min(T.max(D[vi, vj] + A[vi, vk] * B[vj, vk], lower),
upper)
+ for i, j in T.grid(16, 16):
+ with T.block("copy"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ T.reads(temp[vi, vj])
+ T.writes(E[vi, vj])
+ E[vi, vj] = temp[vi, vj]
+
+
+def test_matmul_clipping_multiple_epilogue():
+ """Test fusion with multiple epilogue blocks - one with clipping, one
without.
+
+ Following the same pattern as
test_fuse_reduction_epilogue_multiple_epilogue,
+ this test verifies that fusion works correctly when there are multiple
+ epilogue blocks. The temp buffer is kept because the second epilogue block
+ still needs it.
+ """
+ sch = tir.Schedule(matmul_clipping_multiple_epilogue_before,
debug_mask="all")
+ sch.fuse_reduction_epilogue("matmul", "clipping")
+ assert_structural_equal_ignore_global_symbol(
+ sch.mod["main"], matmul_clipping_multiple_epilogue_expected
+ )
+ verify_trace_roundtrip(sch=sch,
mod=matmul_clipping_multiple_epilogue_before)
+
+ mod = tvm.compile(sch.mod["main"], target="llvm")
+ assert mod is not None
+
+
+# Test commutative variants of clipping patterns
[email protected](
+ "pattern_func",
+ [
+ lambda temp, lower, upper: T.min(T.max(temp, lower), upper), #
min(max(temp, lower), upper)
+ lambda temp, lower, upper: T.min(upper, T.max(temp, lower)), #
min(upper, max(temp, lower))
+ lambda temp, lower, upper: T.min(T.max(lower, temp), upper), #
min(max(lower, temp), upper)
+ lambda temp, lower, upper: T.max(T.min(temp, upper), lower), #
max(min(temp, upper), lower)
+ lambda temp, lower, upper: T.max(lower, T.min(temp, upper)), #
max(lower, min(temp, upper))
+ ],
+)
+def test_matmul_clipping_commutative_variants(pattern_func):
+ """Test that all commutative variants of clipping patterns are
recognized."""
+ lower = -5.0
+ upper = 5.0
+
+ @T.prim_func
+ def test_func(
+ A: T.Buffer((8, 8), "float32"),
+ B: T.Buffer((8, 8), "float32"),
+ D: T.Buffer((8, 8), "float32"),
+ ) -> None:
+ temp = T.alloc_buffer((8, 8), dtype="float32")
+ for i, j, k in T.grid(8, 8, 8):
+ with T.block("matmul"):
+ vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+ with T.init():
+ temp[vi, vj] = T.float32(0)
+ temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk]
+
+ for i, j in T.grid(8, 8):
+ with T.block("clipping"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ D[vi, vj] = pattern_func(temp[vi, vj], T.float32(lower),
T.float32(upper))
+
+ sch = tir.Schedule(test_func, debug_mask="all")
+ # Should not raise an error - all variants should be recognized
+ sch.fuse_reduction_epilogue("matmul", "clipping")
+ verify_trace_roundtrip(sch=sch, mod=test_func)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git
a/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue_relu.py
b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue_relu.py
new file mode 100644
index 0000000000..66e5e52e43
--- /dev/null
+++
b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue_relu.py
@@ -0,0 +1,229 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import pytest
+import tvm
+import tvm.testing
+from tvm import tir
+from tvm.script import tir as T
+from tvm.tir.schedule.testing import (
+ verify_trace_roundtrip,
+ assert_structural_equal_ignore_global_symbol,
+)
+import numpy as np
+
+# pylint: disable=no-member,invalid-name,unused-variable
+
+
[email protected]_func
+def matmul_bias_relu_before(
+ A: T.Buffer((16, 16), "float32"),
+ B: T.Buffer((16, 16), "float32"),
+ C: T.Buffer((16, 16), "float32"),
+ D: T.Buffer((16, 16), "float32"),
+) -> None:
+ """Original function with separate reduction and epilogue blocks (Bias +
ReLU)."""
+ temp = T.alloc_buffer((16, 16), dtype="float32")
+ for i, j, k in T.grid(16, 16, 16):
+ with T.block("matmul"):
+ vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+ with T.init():
+ temp[vi, vj] = T.float32(0)
+ temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk]
+
+ for i, j in T.grid(16, 16):
+ with T.block("bias_relu"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ D[vi, vj] = T.max(temp[vi, vj] + C[vi, vj], T.float32(0))
+
+
[email protected]_func
+def matmul_bias_relu_before_per_iteration(
+ A: T.Buffer((16, 16), "float32"),
+ B: T.Buffer((16, 16), "float32"),
+ C: T.Buffer((16, 16), "float32"),
+ D: T.Buffer((16, 16), "float32"),
+) -> None:
+ """Original function with per-iteration ReLU (same semantics as fused)."""
+ temp = T.alloc_buffer((16, 16), dtype="float32")
+ for i, j in T.grid(16, 16):
+ with T.block("init"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ temp[vi, vj] = T.max(C[vi, vj], T.float32(0)) # ReLU on bias
+
+ for i, j, k in T.grid(16, 16, 16):
+ with T.block("matmul"):
+ vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+ # Per-iteration ReLU
+ temp[vi, vj] = T.max(temp[vi, vj] + A[vi, vk] * B[vj, vk],
T.float32(0))
+
+ for i, j in T.grid(16, 16):
+ with T.block("copy"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ D[vi, vj] = temp[vi, vj]
+
+
[email protected]_func
+def matmul_bias_relu_expected(
+ A: T.Buffer((16, 16), "float32"),
+ B: T.Buffer((16, 16), "float32"),
+ C: T.Buffer((16, 16), "float32"),
+ D: T.Buffer((16, 16), "float32"),
+) -> None:
+ """Expected function after fusion (Bias + ReLU)."""
+ temp = T.alloc_buffer((16, 16), dtype="float32")
+ for i, j, k in T.grid(16, 16, 16):
+ with T.block("matmul"):
+ vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+ T.reads(C[vi, vj], A[vi, vk], B[vj, vk])
+ T.writes(D[vi, vj])
+ with T.init():
+ D[vi, vj] = T.max(C[vi, vj], T.float32(0))
+ D[vi, vj] = T.max(D[vi, vj] + A[vi, vk] * B[vj, vk], T.float32(0))
+
+
+def test_matmul_bias_relu():
+ """Test fusion of matmul with bias + ReLU epilogue."""
+ sch = tir.Schedule(matmul_bias_relu_before, debug_mask="all")
+ sch.fuse_reduction_epilogue("matmul", "bias_relu")
+ assert_structural_equal_ignore_global_symbol(sch.mod["main"],
matmul_bias_relu_expected)
+ verify_trace_roundtrip(sch=sch, mod=matmul_bias_relu_before)
+
+
+def test_matmul_bias_relu_correctness_unified():
+ """Test that original and fused produce identical results with
per-iteration ReLU."""
+ A_np = np.random.randn(16, 16).astype("float32")
+ B_np = np.random.randn(16, 16).astype("float32")
+ C_np = np.random.randn(16, 16).astype("float32")
+
+ # NumPy reference for per-iteration ReLU
+ # Simulate per-iteration ReLU behavior
+ # Original code computes A[vi, vk] * B[vj, vk] which is A[i, k] * B[j, k]
+ # For each k: add outer product of A[:, k] and B[:, k]
+ D_ref = np.maximum(C_np, 0) # init with ReLU on bias
+ for k in range(16):
+ # A[:, k] is shape (16,), B[:, k] is shape (16,)
+ # Outer product: A[:, k] * B[:, k] for all i, j = A[i, k] * B[j, k]
+ # Using broadcasting: A[:, k:k+1] * B[:, k:k+1].T gives (16, 1) * (1,
16) = (16, 16)
+ D_ref = np.maximum(D_ref + np.outer(A_np[:, k], B_np[:, k]), 0)
+
+ # TVM execution (original with per-iteration ReLU)
+ mod_original = tvm.compile(matmul_bias_relu_before_per_iteration,
target="llvm")
+ D_original_tvm = tvm.runtime.tensor(np.zeros((16, 16), dtype="float32"))
+ mod_original(
+ tvm.runtime.tensor(A_np),
+ tvm.runtime.tensor(B_np),
+ tvm.runtime.tensor(C_np),
+ D_original_tvm,
+ )
+
+ # TVM execution (fused)
+ sch = tir.Schedule(matmul_bias_relu_before)
+ sch.fuse_reduction_epilogue("matmul", "bias_relu")
+ mod_fused = tvm.compile(sch.mod["main"], target="llvm")
+ D_fused_tvm = tvm.runtime.tensor(np.zeros((16, 16), dtype="float32"))
+ mod_fused(
+ tvm.runtime.tensor(A_np),
+ tvm.runtime.tensor(B_np),
+ tvm.runtime.tensor(C_np),
+ D_fused_tvm,
+ )
+
+ D_original = D_original_tvm.numpy()
+ D_fused = D_fused_tvm.numpy()
+
+ # Now both should match exactly
+ np.testing.assert_allclose(D_original, D_ref, rtol=1e-5, atol=1e-6)
+ np.testing.assert_allclose(D_fused, D_ref, rtol=1e-5, atol=1e-6)
+ np.testing.assert_allclose(D_original, D_fused, rtol=1e-5, atol=1e-6)
+
+
[email protected]_func
+def matmul_bias_relu_multiple_epilogue_before(
+ A: T.Buffer((16, 16), "float32"),
+ B: T.Buffer((16, 16), "float32"),
+ C: T.Buffer((16, 16), "float32"),
+ D: T.Buffer((16, 16), "float32"),
+ E: T.Buffer((16, 16), "float32"),
+) -> None:
+ """Original function with separate reduction and multiple epilogue blocks
(one with ReLU, one without)."""
+ temp = T.alloc_buffer((16, 16), dtype="float32")
+ for i, j, k in T.grid(16, 16, 16):
+ with T.block("matmul"):
+ vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+ with T.init():
+ temp[vi, vj] = T.float32(0)
+ temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk]
+
+ for i, j in T.grid(16, 16):
+ with T.block("bias_relu"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ D[vi, vj] = T.max(temp[vi, vj] + C[vi, vj], T.float32(0))
+
+ for i, j in T.grid(16, 16):
+ with T.block("bias"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ E[vi, vj] = temp[vi, vj] + C[vi, vj]
+
+
[email protected]_func
+def matmul_bias_relu_multiple_epilogue_expected(
+ A: T.Buffer((16, 16), "float32"),
+ B: T.Buffer((16, 16), "float32"),
+ C: T.Buffer((16, 16), "float32"),
+ D: T.Buffer((16, 16), "float32"),
+ E: T.Buffer((16, 16), "float32"),
+) -> None:
+ """Expected function after fusion (Bias + ReLU) with multiple epilogue
blocks."""
+ temp = T.alloc_buffer((16, 16), dtype="float32")
+ for i, j, k in T.grid(16, 16, 16):
+ with T.block("matmul"):
+ vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+ T.reads(C[vi, vj], A[vi, vk], B[vj, vk])
+ T.writes(D[vi, vj])
+ with T.init():
+ D[vi, vj] = T.max(C[vi, vj], T.float32(0))
+ D[vi, vj] = T.max(D[vi, vj] + A[vi, vk] * B[vj, vk], T.float32(0))
+ for i, j in T.grid(16, 16):
+ with T.block("bias"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ T.reads(temp[vi, vj], C[vi, vj])
+ T.writes(E[vi, vj])
+ E[vi, vj] = temp[vi, vj] + C[vi, vj]
+
+
+def test_matmul_bias_relu_multiple_epilogue():
+ """Test fusion with multiple epilogue blocks - one with ReLU, one without.
+
+ Following the same pattern as
test_fuse_reduction_epilogue_multiple_epilogue,
+ this test verifies that fusion works correctly when there are multiple
+ epilogue blocks. The temp buffer is kept because the second epilogue block
+ still needs it.
+ """
+ sch = tir.Schedule(matmul_bias_relu_multiple_epilogue_before,
debug_mask="all")
+ sch.fuse_reduction_epilogue("matmul", "bias_relu")
+ assert_structural_equal_ignore_global_symbol(
+ sch.mod["main"], matmul_bias_relu_multiple_epilogue_expected
+ )
+ verify_trace_roundtrip(sch=sch,
mod=matmul_bias_relu_multiple_epilogue_before)
+
+ mod = tvm.compile(sch.mod["main"], target="llvm")
+ assert mod is not None
+
+
+if __name__ == "__main__":
+ tvm.testing.main()