This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b4dc8a5d2f [BugFix][TIR] Fix incorrect optimization when lowering
floordiv and floormod (#18694)
b4dc8a5d2f is described below
commit b4dc8a5d2f03c9bbb87cf814b00386127c0273fc
Author: guocj <[email protected]>
AuthorDate: Fri Jan 30 21:13:49 2026 +0800
[BugFix][TIR] Fix incorrect optimization when lowering floordiv and floormod
(#18694)
This patch fixes an issue in the LowerIntrin pass where incorrect
optimizations were applied to floordiv and floormod operations.
The root cause is that, to express floordiv(a, b) (and likewise
floormod(a, b)) using only non-negative operands, the pass computes the
expression (b - 1) - a_min, where a_min is the constant lower bound of a.
After constant folding, this expression can overflow the range of the
dtype (e.g. int32 or int16). When this overflow occurs, the rewritten
expression is no longer equivalent to the original operation.
To fix this, we tightened the condition under which the transformation is
applied. The transformation is now only performed when
b_max - a_min <= max_value(dtype) + 1 (i.e. b_max - a_min is strictly less
than max_value(dtype) + 2), where b_max is the constant upper bound of b.
This guarantees that (b - 1) - a_min fits in the dtype. If this
condition is not met, the transformation is skipped and the generic
lowering path is used instead to ensure correctness.
A regression test has been added to cover this case.
Fixes #18684
---
src/tir/transforms/lower_intrin.cc | 40 ++++++++++++++++------
.../test_tir_transform_lower_intrin.py | 12 +++++++
2 files changed, 42 insertions(+), 10 deletions(-)
diff --git a/src/tir/transforms/lower_intrin.cc
b/src/tir/transforms/lower_intrin.cc
index 4c35fdb290..6a7d2b2776 100644
--- a/src/tir/transforms/lower_intrin.cc
+++ b/src/tir/transforms/lower_intrin.cc
@@ -118,10 +118,20 @@ class IntrinInjecter : public
tvm::arith::IRMutatorWithAnalyzer {
// If the numerator's lower bound is known, express the floordiv
// in terms of truncdiv using only positive operands.
- arith::ConstIntBound const_int_bound = analyzer_->const_int_bound(op->a);
- if (const_int_bound->min_value < 0 &&
- const_int_bound->min_value >
-
-(Downcast<IntImm>(tvm::max_value(op->a->dtype.element_of()))->value)) {
+
+ // The optimization below rewrites expressions involving `-a_min + (b -
1)`.
+ // Without proper bounds checking, this expression may overflow the dtype
+ // maximum, leading to non-equivalent transformations.
+ // To ensure safety, we require:
+ // b_max - a_min <= max_value_of_dtype + 1
+ // This provides a conservative upper bound that prevents overflow and
+ // preserves the original semantics.
+ arith::ConstIntBound const_int_bound_a =
analyzer_->const_int_bound(op->a);
+ arith::ConstIntBound const_int_bound_b =
analyzer_->const_int_bound(op->b);
+ const int64_t max_value_of_dtype =
+ Downcast<IntImm>(tvm::max_value(op->a->dtype.element_of()))->value;
+ if (const_int_bound_a->min_value < 0 &&
+ const_int_bound_b->max_value - const_int_bound_a->min_value <=
max_value_of_dtype + 1) {
// The goal is to write floordiv(a,b) in terms of truncdiv, without
using
// negative operands.
//
@@ -152,7 +162,7 @@ class IntrinInjecter : public
tvm::arith::IRMutatorWithAnalyzer {
// floordiv(a,b)
// == floordiv(a + b*c, b) - c
// == truncdiv(a + b*c, b) - c
- IntImm min(op->a->dtype.element_of(), const_int_bound->min_value);
+ IntImm min(op->a->dtype.element_of(), const_int_bound_a->min_value);
PrimExpr ceildiv = truncdiv((op->b - 1) - min, op->b);
PrimExpr offset_numerator = analyzer_->Simplify(op->a + op->b *
ceildiv);
return truncdiv(offset_numerator, op->b) - ceildiv;
@@ -214,10 +224,20 @@ class IntrinInjecter : public
tvm::arith::IRMutatorWithAnalyzer {
// If the numerator's lower bound is known, express the floormod
// in terms of truncmod using only positive operands.
- arith::ConstIntBound const_int_bound = analyzer_->const_int_bound(op->a);
- if (const_int_bound->min_value < 0 &&
- const_int_bound->min_value >
-
-(Downcast<IntImm>(tvm::max_value(op->a->dtype.element_of()))->value)) {
+
+ // The optimization below rewrites expressions involving `-a_min + (b -
1)`.
+ // Without proper bounds checking, this expression may overflow the dtype
+ // maximum, leading to non-equivalent transformations.
+ // To ensure safety, we require:
+ // b_max - a_min <= max_value_of_dtype + 1
+ // This provides a conservative upper bound that prevents overflow and
+ // preserves the original semantics.
+ arith::ConstIntBound const_int_bound_a =
analyzer_->const_int_bound(op->a);
+ arith::ConstIntBound const_int_bound_b =
analyzer_->const_int_bound(op->b);
+ const int64_t max_value_of_dtype =
+ Downcast<IntImm>(tvm::max_value(op->a->dtype.element_of()))->value;
+ if (const_int_bound_a->min_value < 0 &&
+ const_int_bound_b->max_value - const_int_bound_a->min_value <=
max_value_of_dtype + 1) {
// The goal is to write floormod(a,b) in terms of truncdiv and
truncmod,
// without using negative operands.
//
@@ -247,7 +267,7 @@ class IntrinInjecter : public
tvm::arith::IRMutatorWithAnalyzer {
// floormod(a,b)
// == floormod(a + b*c, b)
// == truncmod(a + b*c, b)
- IntImm min(op->a->dtype.element_of(), const_int_bound->min_value);
+ IntImm min(op->a->dtype.element_of(), const_int_bound_a->min_value);
PrimExpr ceildiv = truncdiv(-min + (op->b - 1), op->b);
PrimExpr offset_numerator = analyzer_->Simplify(op->a + op->b *
ceildiv);
return truncmod(offset_numerator, op->b);
diff --git a/tests/python/tir-transform/test_tir_transform_lower_intrin.py
b/tests/python/tir-transform/test_tir_transform_lower_intrin.py
index 864b24bc0f..63f37e6f41 100644
--- a/tests/python/tir-transform/test_tir_transform_lower_intrin.py
+++ b/tests/python/tir-transform/test_tir_transform_lower_intrin.py
@@ -92,6 +92,12 @@ def test_lower_floordiv():
# const power of two
res = lower_intrin([x, y], tvm.te.floordiv(x, tvm.tir.const(8,
dtype=dtype)))
check_value(res, x, y, [(a, b) for a, b in data if b == 8], lambda a,
b: a // b)
+ # floordiv(x + m, k), m and k are positive constant. 2 <= m <= k-1.
+ res = lower_intrin(
+ [x, y],
+ tvm.te.floordiv(x + tvm.tir.const(4, dtype=dtype),
tvm.tir.const(5, dtype=dtype)),
+ )
+ check_value(res, x, y, [(a, b) for a, b in data if b == 5], lambda a,
b: (a + 4) // b)
@tvm.testing.requires_llvm
@@ -115,6 +121,12 @@ def test_lower_floormod():
# const power of two
res = lower_intrin([x, y], tvm.te.floormod(x, tvm.tir.const(8,
dtype=dtype)))
check_value(res, x, y, [(a, b) for a, b in data if b == 8], lambda a,
b: a % b)
+ # floormod(x + m, k), m and k are positive constant. 2 <= m <= k-1.
+ res = lower_intrin(
+ [x, y],
+ tvm.te.floormod(x + tvm.tir.const(4, dtype=dtype),
tvm.tir.const(5, dtype=dtype)),
+ )
+ check_value(res, x, y, [(a, b) for a, b in data if b == 5], lambda a,
b: (a + 4) % b)
if __name__ == "__main__":