This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b4dc8a5d2f [BugFix][TIR] Fix incorrect optimization when lowering floordiv and floormod (#18694)
b4dc8a5d2f is described below

commit b4dc8a5d2f03c9bbb87cf814b00386127c0273fc
Author: guocj <[email protected]>
AuthorDate: Fri Jan 30 21:13:49 2026 +0800

    [BugFix][TIR] Fix incorrect optimization when lowering floordiv and floormod (#18694)
    
    This patch fixes an issue in the LowerIntrin pass where incorrect
    optimizations were applied to floordiv and floormod operations.
    
    The root cause is that, to rewrite floordiv(a, b) (and likewise
    floormod(a, b)) in terms of truncdiv/truncmod, the pass builds the
    expression (op->b - 1) - a_min, where a_min is the known lower bound
    of a. When this expression is constant-folded, the result can overflow
    the range of the operand dtype (e.g. int32 or int16). When this
    overflow occurs, the transformation becomes invalid and is no longer
    equivalent to the original operation.
    
    To fix this, we tightened the condition under which the transformation
    is applied. The transformation is now performed only when
    b_max - a_min <= max_value_of_dtype + 1, where b_max is the known
    upper bound of b (for int32 this is b_max - a_min < INT_MAX + 2). If
    this condition is not met, the transformation is skipped and the
    common lowering steps are followed to ensure correctness.
    
    Regression tests covering this case have been added for both floordiv
    and floormod.
    
    Fixes #18684
---
 src/tir/transforms/lower_intrin.cc                 | 40 ++++++++++++++++------
 .../test_tir_transform_lower_intrin.py             | 12 +++++++
 2 files changed, 42 insertions(+), 10 deletions(-)

diff --git a/src/tir/transforms/lower_intrin.cc 
b/src/tir/transforms/lower_intrin.cc
index 4c35fdb290..6a7d2b2776 100644
--- a/src/tir/transforms/lower_intrin.cc
+++ b/src/tir/transforms/lower_intrin.cc
@@ -118,10 +118,20 @@ class IntrinInjecter : public 
tvm::arith::IRMutatorWithAnalyzer {
 
       // If the numerator's lower bound is known, express the floordiv
       // in terms of truncdiv using only positive operands.
-      arith::ConstIntBound const_int_bound = analyzer_->const_int_bound(op->a);
-      if (const_int_bound->min_value < 0 &&
-          const_int_bound->min_value >
-              
-(Downcast<IntImm>(tvm::max_value(op->a->dtype.element_of()))->value)) {
+
+      // The optimization below rewrites expressions involving `-a_min + (b - 
1)`.
+      // Without proper bounds checking, this expression may overflow the dtype
+      // maximum, leading to non-equivalent transformations.
+      // To ensure safety, we require:
+      //   b_max - a_min <= max_value_of_dtype + 1
+      // This provides a conservative upper bound that prevents overflow and
+      // preserves the original semantics.
+      arith::ConstIntBound const_int_bound_a = 
analyzer_->const_int_bound(op->a);
+      arith::ConstIntBound const_int_bound_b = 
analyzer_->const_int_bound(op->b);
+      const int64_t max_value_of_dtype =
+          Downcast<IntImm>(tvm::max_value(op->a->dtype.element_of()))->value;
+      if (const_int_bound_a->min_value < 0 &&
+          const_int_bound_b->max_value - const_int_bound_a->min_value <= 
max_value_of_dtype + 1) {
         // The goal is to write floordiv(a,b) in terms of truncdiv, without 
using
         // negative operands.
         //
@@ -152,7 +162,7 @@ class IntrinInjecter : public 
tvm::arith::IRMutatorWithAnalyzer {
         //   floordiv(a,b)
         //     == floordiv(a + b*c, b) - c
         //     == truncdiv(a + b*c, b) - c
-        IntImm min(op->a->dtype.element_of(), const_int_bound->min_value);
+        IntImm min(op->a->dtype.element_of(), const_int_bound_a->min_value);
         PrimExpr ceildiv = truncdiv((op->b - 1) - min, op->b);
         PrimExpr offset_numerator = analyzer_->Simplify(op->a + op->b * 
ceildiv);
         return truncdiv(offset_numerator, op->b) - ceildiv;
@@ -214,10 +224,20 @@ class IntrinInjecter : public 
tvm::arith::IRMutatorWithAnalyzer {
 
       // If the numerator's lower bound is known, express the floormod
       // in terms of truncmod using only positive operands.
-      arith::ConstIntBound const_int_bound = analyzer_->const_int_bound(op->a);
-      if (const_int_bound->min_value < 0 &&
-          const_int_bound->min_value >
-              
-(Downcast<IntImm>(tvm::max_value(op->a->dtype.element_of()))->value)) {
+
+      // The optimization below rewrites expressions involving `-a_min + (b - 
1)`.
+      // Without proper bounds checking, this expression may overflow the dtype
+      // maximum, leading to non-equivalent transformations.
+      // To ensure safety, we require:
+      //   b_max - a_min <= max_value_of_dtype + 1
+      // This provides a conservative upper bound that prevents overflow and
+      // preserves the original semantics.
+      arith::ConstIntBound const_int_bound_a = 
analyzer_->const_int_bound(op->a);
+      arith::ConstIntBound const_int_bound_b = 
analyzer_->const_int_bound(op->b);
+      const int64_t max_value_of_dtype =
+          Downcast<IntImm>(tvm::max_value(op->a->dtype.element_of()))->value;
+      if (const_int_bound_a->min_value < 0 &&
+          const_int_bound_b->max_value - const_int_bound_a->min_value <= 
max_value_of_dtype + 1) {
         // The goal is to write floormod(a,b) in terms of truncdiv and 
truncmod,
         // without using negative operands.
         //
@@ -247,7 +267,7 @@ class IntrinInjecter : public 
tvm::arith::IRMutatorWithAnalyzer {
         //   floormod(a,b)
         //     == floormod(a + b*c, b)
         //     == truncmod(a + b*c, b)
-        IntImm min(op->a->dtype.element_of(), const_int_bound->min_value);
+        IntImm min(op->a->dtype.element_of(), const_int_bound_a->min_value);
         PrimExpr ceildiv = truncdiv(-min + (op->b - 1), op->b);
         PrimExpr offset_numerator = analyzer_->Simplify(op->a + op->b * 
ceildiv);
         return truncmod(offset_numerator, op->b);
diff --git a/tests/python/tir-transform/test_tir_transform_lower_intrin.py 
b/tests/python/tir-transform/test_tir_transform_lower_intrin.py
index 864b24bc0f..63f37e6f41 100644
--- a/tests/python/tir-transform/test_tir_transform_lower_intrin.py
+++ b/tests/python/tir-transform/test_tir_transform_lower_intrin.py
@@ -92,6 +92,12 @@ def test_lower_floordiv():
         # const power of two
         res = lower_intrin([x, y], tvm.te.floordiv(x, tvm.tir.const(8, 
dtype=dtype)))
         check_value(res, x, y, [(a, b) for a, b in data if b == 8], lambda a, 
b: a // b)
+        # floordiv(x + m, k), m and k are positive constant. 2 <= m <= k-1.
+        res = lower_intrin(
+            [x, y],
+            tvm.te.floordiv(x + tvm.tir.const(4, dtype=dtype), 
tvm.tir.const(5, dtype=dtype)),
+        )
+        check_value(res, x, y, [(a, b) for a, b in data if b == 5], lambda a, 
b: (a + 4) // b)
 
 
 @tvm.testing.requires_llvm
@@ -115,6 +121,12 @@ def test_lower_floormod():
         # const power of two
         res = lower_intrin([x, y], tvm.te.floormod(x, tvm.tir.const(8, 
dtype=dtype)))
         check_value(res, x, y, [(a, b) for a, b in data if b == 8], lambda a, 
b: a % b)
+        # floormod(x + m, k), m and k are positive constant. 2 <= m <= k-1.
+        res = lower_intrin(
+            [x, y],
+            tvm.te.floormod(x + tvm.tir.const(4, dtype=dtype), 
tvm.tir.const(5, dtype=dtype)),
+        )
+        check_value(res, x, y, [(a, b) for a, b in data if b == 5], lambda a, 
b: (a + 4) % b)
 
 
 if __name__ == "__main__":

Reply via email to