This is an automated email from the ASF dual-hosted git repository.

bohan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4b1bd6d431 [TIR] Further robustify floordiv/mod intrin lowering to prevent overflow (#18699)
4b1bd6d431 is described below

commit 4b1bd6d431a662635b4f2817b2f7edd8be8ccbd3
Author: Tianqi Chen <[email protected]>
AuthorDate: Fri Jan 30 13:17:48 2026 -0500

    [TIR] Further robustify floordiv/mod intrin lowering to prevent overflow (#18699)
    
    This PR further robustifies floordiv/mod intrin lowering in cases where
    we shift a possibly negative dividend into the non-negative range, while
    carefully preventing overflow both in the intermediate compile-time bound
    checks and in the resulting runtime expression.
---
 src/tir/transforms/lower_intrin.cc                 | 161 ++++++++-------------
 .../test_tir_transform_lower_intrin.py             |  89 +++++++++---
 2 files changed, 125 insertions(+), 125 deletions(-)

diff --git a/src/tir/transforms/lower_intrin.cc 
b/src/tir/transforms/lower_intrin.cc
index 6a7d2b2776..ef844d9e05 100644
--- a/src/tir/transforms/lower_intrin.cc
+++ b/src/tir/transforms/lower_intrin.cc
@@ -115,59 +115,15 @@ class IntrinInjecter : public 
tvm::arith::IRMutatorWithAnalyzer {
       if (analyzer_->CanProveGreaterEqual(op->a, 0) || 
analyzer_->CanProveGreaterEqual(e, 0)) {
         return truncdiv(op->a, op->b);
       }
-
-      // If the numerator's lower bound is known, express the floordiv
-      // in terms of truncdiv using only positive operands.
-
-      // The optimization below rewrites expressions involving `-a_min + (b - 
1)`.
-      // Without proper bounds checking, this expression may overflow the dtype
-      // maximum, leading to non-equivalent transformations.
-      // To ensure safety, we require:
-      //   b_max - a_min <= max_value_of_dtype + 1
-      // This provides a conservative upper bound that prevents overflow and
-      // preserves the original semantics.
-      arith::ConstIntBound const_int_bound_a = 
analyzer_->const_int_bound(op->a);
-      arith::ConstIntBound const_int_bound_b = 
analyzer_->const_int_bound(op->b);
-      const int64_t max_value_of_dtype =
-          Downcast<IntImm>(tvm::max_value(op->a->dtype.element_of()))->value;
-      if (const_int_bound_a->min_value < 0 &&
-          const_int_bound_b->max_value - const_int_bound_a->min_value <= 
max_value_of_dtype + 1) {
-        // The goal is to write floordiv(a,b) in terms of truncdiv, without 
using
-        // negative operands.
-        //
-        // For any integer c
-        //
-        //   floordiv(a,b) == floordiv(a + b*c - b*c, b)
-        //                 == floordiv(a + b*c, b) - c
-        //
-        // Choosing `c = ceildiv(-a_min, b)`.  This can be rewritten in terms 
of
-        // truncdiv as follows.
-        //
-        //   c == ceildiv(-a_min,b)
-        //     == floordiv(-a_min + (b-1), b)
-        //     == truncdiv(-a_min + (b-1), b)
-        //
-        // When substituted into `a + b*c`, this results in a positive 
argument.
-        //
-        //   a + b*c
-        //     == a + b*ceildiv(-a_min,b)
-        //     == a - b*floordiv(a_min,b)
-        //     >= a - b*floordiv(a,b)
-        //     == floormod(a, b)
-        //     >= 0
-        //
-        // Since the argument is positive, this allows floordiv to be written 
as
-        // followed.
-        //
-        //   floordiv(a,b)
-        //     == floordiv(a + b*c, b) - c
-        //     == truncdiv(a + b*c, b) - c
-        IntImm min(op->a->dtype.element_of(), const_int_bound_a->min_value);
-        PrimExpr ceildiv = truncdiv((op->b - 1) - min, op->b);
-        PrimExpr offset_numerator = analyzer_->Simplify(op->a + op->b * 
ceildiv);
-        return truncdiv(offset_numerator, op->b) - ceildiv;
+      if (const IntImmNode* b_as_intimm = op->b.as<IntImmNode>()) {
+        int64_t b_value = b_as_intimm->value;
+        if (auto opt_c_value = TryFindShiftCoefficientForPositiveRange(op->a, 
b_value)) {
+          int64_t c_value = *opt_c_value;
+          // now we can safely lower to truncdiv
+          return truncdiv(op->a + make_const(dtype, b_value * c_value), op->b) 
-
+                 make_const(dtype, c_value);
+        }
       }
-
       DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divident";
       PrimExpr rdiv = truncdiv(op->a, op->b);
       PrimExpr rmod = truncmod(op->a, op->b);
@@ -221,58 +177,14 @@ class IntrinInjecter : public 
tvm::arith::IRMutatorWithAnalyzer {
       if (analyzer_->CanProveGreaterEqual(op->a, 0)) {
         return truncmod(op->a, op->b);
       }
-
-      // If the numerator's lower bound is known, express the floormod
-      // in terms of truncmod using only positive operands.
-
-      // The optimization below rewrites expressions involving `-a_min + (b - 
1)`.
-      // Without proper bounds checking, this expression may overflow the dtype
-      // maximum, leading to non-equivalent transformations.
-      // To ensure safety, we require:
-      //   b_max - a_min <= max_value_of_dtype + 1
-      // This provides a conservative upper bound that prevents overflow and
-      // preserves the original semantics.
-      arith::ConstIntBound const_int_bound_a = 
analyzer_->const_int_bound(op->a);
-      arith::ConstIntBound const_int_bound_b = 
analyzer_->const_int_bound(op->b);
-      const int64_t max_value_of_dtype =
-          Downcast<IntImm>(tvm::max_value(op->a->dtype.element_of()))->value;
-      if (const_int_bound_a->min_value < 0 &&
-          const_int_bound_b->max_value - const_int_bound_a->min_value <= 
max_value_of_dtype + 1) {
-        // The goal is to write floormod(a,b) in terms of truncdiv and 
truncmod,
-        // without using negative operands.
-        //
-        // For any integer c
-        //
-        //   floormod(a, b) == floormod(a + b*c, b)
-        //
-        // Choosing `c = ceildiv(-a_min, b)`.  This can be rewritten in terms 
of
-        // truncdiv as follows.
-        //
-        //   c == ceildiv(-a_min,b)
-        //     == floordiv(-a_min + (b-1), b)
-        //     == truncdiv(-a_min + (b-1), b)
-        //
-        // When substituted into `a + b*c`, this results in a positive 
argument.
-        //
-        //   a + b*c
-        //     == a + b*ceildiv(-a_min,b)
-        //     == a - b*floordiv(a_min,b)
-        //     >= a - b*floordiv(a,b)
-        //     == floormod(a, b)
-        //     >= 0
-        //
-        // Since the argument is positive, this allows floordiv to be written 
as
-        // followed.
-        //
-        //   floormod(a,b)
-        //     == floormod(a + b*c, b)
-        //     == truncmod(a + b*c, b)
-        IntImm min(op->a->dtype.element_of(), const_int_bound_a->min_value);
-        PrimExpr ceildiv = truncdiv(-min + (op->b - 1), op->b);
-        PrimExpr offset_numerator = analyzer_->Simplify(op->a + op->b * 
ceildiv);
-        return truncmod(offset_numerator, op->b);
+      if (const IntImmNode* b_as_intimm = op->b.as<IntImmNode>()) {
+        int64_t b_value = b_as_intimm->value;
+        if (auto opt_c_value = TryFindShiftCoefficientForPositiveRange(op->a, 
b_value)) {
+          int64_t c_value = *opt_c_value;
+          // floormod(a, b) == floormod(a + b*c, b)  == truncmod(a + b*c, b)
+          return truncmod(op->a + make_const(dtype, c_value * b_value), op->b);
+        }
       }
-
       DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divident";
       // NOTE:condition on b >= 0.
       // mod(a, b) < 0 will imply we are doing ceildiv,
@@ -388,6 +300,49 @@ class IntrinInjecter : public 
tvm::arith::IRMutatorWithAnalyzer {
     return IRMutatorWithAnalyzer::VisitExpr_(op);
   }
 
+  /*!
+   * \brief Try to find a shift co-efficient c such that a + b*c positive and 
does not overflow.
+   *
+   * \param a the dividend
+   * \param b_value the divisor
+   * \return the shift co-efficient c, or nullopt if not found
+   */
+  std::optional<int64_t> TryFindShiftCoefficientForPositiveRange(const 
PrimExpr& a,
+                                                                 int64_t 
b_value) {
+    if (b_value <= 0) {
+      return std::nullopt;
+    }
+    // NOTE: we need to be very careful in the checks below, to make sure
+    // all the intermediate calculations in both compiler checks and runtime 
checks
+    // do not overflow
+    arith::ConstIntBound const_int_bound_a = analyzer_->const_int_bound(a);
+    if (const_int_bound_a->min_value >= 0) {
+      return std::nullopt;
+    }
+    const int64_t max_value_of_dtype =
+        Downcast<IntImm>(tvm::max_value(a->dtype.element_of()))->value;
+
+    // NOTE: ensures that (b-1) - a_min does not overflow
+    // also note: max_value_of_dtype + const_int_bound_a->min_value won't 
overflow
+    // since a_min is negative, adding it to a positive value will not overflow
+    if (b_value - 1 > max_value_of_dtype + const_int_bound_a->min_value) {
+      return std::nullopt;
+    }
+    int64_t c_value = ((b_value - 1) - const_int_bound_a->min_value) / b_value;
+    ICHECK_GT(c_value, 0);
+    // NOTE: the c_value * b_value risks in overflow
+    if (c_value > max_value_of_dtype / b_value) return std::nullopt;
+    // need to check if the offset numerator will overflow
+    // to ensure if don't overflow, we need to use max_value_of_dtype - 
b_value * c_value
+    // note that b_value * c_value is positive, max_value_of_dtype is also 
positive, so the
+    // subtraction will not overflow
+    if (const_int_bound_a->max_value > max_value_of_dtype - b_value * c_value) 
{
+      // a + b * c risks overflow
+      return std::nullopt;
+    }
+    return c_value;
+  }
+
   // attribute maps, shared only when FLegalize == FLowerIntrinsic
   std::vector<OpAttrMap<FLowerGeneral>> attr_maps_;
   FLowerGeneral fma_{nullptr};
diff --git a/tests/python/tir-transform/test_tir_transform_lower_intrin.py 
b/tests/python/tir-transform/test_tir_transform_lower_intrin.py
index 63f37e6f41..a0a6ab2508 100644
--- a/tests/python/tir-transform/test_tir_transform_lower_intrin.py
+++ b/tests/python/tir-transform/test_tir_transform_lower_intrin.py
@@ -35,24 +35,35 @@ def lower_intrin(params, stmt):
     return stmt.value if lower_expr else stmt.body
 
 
-def check_value(expr, vx, vy, data, fref):
+def check_value(expr, variables, data, fref):
+    """
+    Check that expr evaluates to fref(*row) for each row in data.
+    variables: list of TIR vars [x] or [x, y] bound to the columns of data.
+    data: list of tuples, each tuple has len(variables) elements.
+    """
     n = len(data)
-    A = te.placeholder((n,), name="A", dtype=expr.dtype)
-    B = te.placeholder((n,), name="B", dtype=expr.dtype)
+    num_vars = len(variables)
+    assert num_vars >= 1 and all(len(row) == num_vars for row in data)
+
+    placeholders = [
+        te.placeholder((n,), name=f"v{i}", dtype=variables[i].dtype) for i in 
range(num_vars)
+    ]
 
     def make_binds(i):
         x = expr
-        x = tvm.tir.Let(vx, A[i], x)
-        x = tvm.tir.Let(vy, B[i], x)
+        for j in range(num_vars - 1, -1, -1):
+            x = tvm.tir.Let(variables[j], placeholders[j][i], x)
         return x
 
     C = te.compute((n,), make_binds)
-    f = tvm.compile(te.create_prim_func([A, B, C]), "llvm")
-    a = tvm.runtime.tensor(np.array([x for x, y in data], dtype=expr.dtype))
-    b = tvm.runtime.tensor(np.array([y for x, y in data], dtype=expr.dtype))
-    c = tvm.runtime.tensor(np.zeros(len(data), dtype=expr.dtype))
-    f(a, b, c)
-    cref = np.array([fref(x, y) for x, y in data])
+    f = tvm.compile(te.create_prim_func(placeholders + [C]), "llvm")
+    arrays = [
+        tvm.runtime.tensor(np.array([row[j] for row in data], 
dtype=variables[j].dtype))
+        for j in range(num_vars)
+    ]
+    c = tvm.runtime.tensor(np.zeros(n, dtype=expr.dtype))
+    f(*arrays, c)
+    cref = np.array([fref(*row) for row in data])
     np.testing.assert_equal(c.numpy(), cref)
 
 
@@ -75,29 +86,29 @@ def test_lower_floordiv():
         zero = tvm.tir.const(0, dtype)
         # no constraints
         res = lower_intrin([x, y], tvm.te.floordiv(x, y))
-        check_value(res, x, y, data, lambda a, b: a // b)
+        check_value(res, [x, y], data, lambda a, b: a // b)
         # rhs >= 0
         res = lower_intrin([x, y], tvm.tir.Select(y >= 0, tvm.te.floordiv(x, 
y), zero))
-        check_value(res, x, y, data, lambda a, b: a // b if b > 0 else 0)
+        check_value(res, [x, y], data, lambda a, b: a // b if b > 0 else 0)
         # involves max
         res = lower_intrin(
             [x, y], tvm.tir.Select(y >= 0, tvm.te.max(tvm.te.floordiv(x, y), 
zero), zero)
         )
-        check_value(res, x, y, data, lambda a, b: max(a // b, 0) if b > 0 else 
0)
+        check_value(res, [x, y], data, lambda a, b: max(a // b, 0) if b > 0 
else 0)
         # lhs >= 0
         res = lower_intrin(
             [x, y], tvm.tir.Select(tvm.tir.all(y >= 0, x >= 0), 
tvm.te.floordiv(x, y), zero)
         )
-        check_value(res, x, y, data, lambda a, b: a // b if b > 0 and a >= 0 
else 0)
+        check_value(res, [x, y], data, lambda a, b: a // b if b > 0 and a >= 0 
else 0)
         # const power of two
         res = lower_intrin([x, y], tvm.te.floordiv(x, tvm.tir.const(8, 
dtype=dtype)))
-        check_value(res, x, y, [(a, b) for a, b in data if b == 8], lambda a, 
b: a // b)
+        check_value(res, [x, y], [(a, b) for a, b in data if b == 8], lambda 
a, b: a // b)
         # floordiv(x + m, k), m and k are positive constant. 2 <= m <= k-1.
         res = lower_intrin(
             [x, y],
             tvm.te.floordiv(x + tvm.tir.const(4, dtype=dtype), 
tvm.tir.const(5, dtype=dtype)),
         )
-        check_value(res, x, y, [(a, b) for a, b in data if b == 5], lambda a, 
b: (a + 4) // b)
+        check_value(res, [x, y], [(a, b) for a, b in data if b == 5], lambda 
a, b: (a + 4) // b)
 
 
 @tvm.testing.requires_llvm
@@ -109,26 +120,60 @@ def test_lower_floormod():
         zero = tvm.tir.const(0, dtype)
         # no constraints
         res = lower_intrin([x, y], tvm.te.floormod(x, y))
-        check_value(res, x, y, data, lambda a, b: a % b)
+        check_value(res, [x, y], data, lambda a, b: a % b)
         # rhs >= 0
         res = lower_intrin([x, y], tvm.tir.Select(y >= 0, tvm.te.floormod(x, 
y), zero))
-        check_value(res, x, y, data, lambda a, b: a % b if b > 0 else 0)
+        check_value(res, [x, y], data, lambda a, b: a % b if b > 0 else 0)
         # lhs >= 0
         res = lower_intrin(
             [x, y], tvm.tir.Select(tvm.tir.all(y >= 0, x >= 0), 
tvm.te.floormod(x, y), zero)
         )
-        check_value(res, x, y, data, lambda a, b: a % b if b > 0 and a >= 0 
else 0)
+        check_value(res, [x, y], data, lambda a, b: a % b if b > 0 and a >= 0 
else 0)
         # const power of two
         res = lower_intrin([x, y], tvm.te.floormod(x, tvm.tir.const(8, 
dtype=dtype)))
-        check_value(res, x, y, [(a, b) for a, b in data if b == 8], lambda a, 
b: a % b)
+        check_value(res, [x, y], [(a, b) for a, b in data if b == 8], lambda 
a, b: a % b)
         # floormod(x + m, k), m and k are positive constant. 2 <= m <= k-1.
         res = lower_intrin(
             [x, y],
             tvm.te.floormod(x + tvm.tir.const(4, dtype=dtype), 
tvm.tir.const(5, dtype=dtype)),
         )
-        check_value(res, x, y, [(a, b) for a, b in data if b == 5], lambda a, 
b: (a + 4) % b)
+        check_value(res, [x, y], [(a, b) for a, b in data if b == 5], lambda 
a, b: (a + 4) % b)
+
+
+@tvm.testing.requires_llvm
+def test_lower_floordiv_overflow_checks():
+    """
+    Regression tests for overflow checks in 
TryFindShiftCoefficientForPositiveRange.
+    Divisor is constant 3 (not 1 to avoid CSE, not power-of-two so we don't 
take the shift path).
+    Reuses lower_intrin and check_value; overflow tests use one var [x].
+    """
+    # Check 3: (b-1) - a_min must not overflow (numerator and C++ int64).
+    # x (int64) full range -> min_value = -2^63. With b = 3: numerator = 2 - 
(-2^63) > LLONG_MAX.
+    x = te.var("x", dtype="int64")
+    res = lower_intrin([x], tvm.te.floordiv(x, tvm.tir.const(3, "int64")))
+    data_check3 = [(-(2**63),), (0,), (100,)]
+    check_value(res, [x], data_check3, lambda a: a // 3)
+
+    # Check 4: c_value * b_value must not overflow dtype.
+    # x (int16) full range -> min_value = -32768, c = ceil(32770/3) = 10923; 
10923*3 > 32767.
+    x = te.var("x", dtype="int16")
+    res = lower_intrin([x], tvm.te.floordiv(x, tvm.tir.const(3, "int16")))
+    data_check4 = [(-32768,), (0,), (100,)]
+    check_value(res, [x], data_check4, lambda a: a // 3)
+
+    # Check 5: a_max + b*c must not overflow (offset numerator).
+    # tir.min(tir.max(x, -10), 32758) can give bounds [-10, 32758]; b=3, c=4; 
a_max + 12 > 32767.
+    # In practice this path may not be triggered. This test still validates 
correct lowering.
+    x = te.var("x", dtype="int16")
+    clamped = tvm.tir.min(
+        tvm.tir.max(x, tvm.tir.const(-10, "int16")), tvm.tir.const(32758, 
"int16")
+    )
+    res = lower_intrin([x], tvm.te.floordiv(clamped, tvm.tir.const(3, 
"int16")))
+    data_check5 = [(-10,), (0,), (32758,), (32757,)]
+    check_value(res, [x], data_check5, lambda a: (min(max(a, -10), 32758)) // 
3)
 
 
 if __name__ == "__main__":
     test_lower_floordiv()
     test_lower_floormod()
+    test_lower_floordiv_overflow_checks()

Reply via email to