This is an automated email from the ASF dual-hosted git repository. spectrometerHBH pushed a commit to branch tir-bench in repository https://gitbox.apache.org/repos/asf/tvm.git
commit f715671412df16eeb9ed4455597be6e0b67b5a20
Author: Hongyi Jin <[email protected]>
AuthorDate: Thu May 28 14:57:11 2026 -0400

fix(arith): gate canonical-simplify LT Case 2 on extra scale == +1 (#651)

CanonicalSimplifier::Impl::VisitExpr_(LTNode) Case 2 rewrites a
"scaled-by-d sum plus a single leftover split" comparison

    S + xn < 0  ⇔  S/d + (xn // d) < 0    where d = gcd(scales)

into one where the leftover yn % m gets replaced by
floormod(floordiv(yn, d*L), m/(d*L)). The Case 1 derivation that
justifies dropping the remainder xn % d ∈ [0, d) only works when
xn ≥ 0. With scale = -1 the equivalence becomes ≤ rather than <, and
the rewrite silently strengthens the predicate by dropping the boundary
case S/d == xn // d.

This surfaced as a miscompile in TIRx kernels that mask a per-lane write
by `row > col`, where `row = (lane_id // 4) + 16 * warp_id` and
`col = 2 * (lane_id % 4)` are independent projections of the same lane
id. After CSE+inlining the comparison hit canonical_simplify with the
divided projection on the LHS (scale = -1), and Case 2 folded
`2*(tx%4) < 16*warp + (tx%32)//4` into a plain `0 < warp_id`, zeroing
every thread that should have written `val` in warp 0. The same path
also folded other configurations (e.g. `0 < (tx%32) - 8*warp`) all the
way to False.

Gate Case 2 with `extra->args[0]->scale == 1`. The original target shape
(`(yn % m)` with positive scale and lower_factor=1, as well as the
scale=+1 + lower_factor>1 generalization) is unchanged; both are covered
by the existing `test_simplify_le` cases and by the new
`test_simplify_le_negative_scale_extra` regression test, which also pins
the buggy scale=-1 shape to its unsimplified form and re-asserts that the
truly-always-true `r=2` variant still folds to True.
--- src/arith/canonical_simplify.cc | 11 +++++- .../python/arith/test_arith_canonical_simplify.py | 44 ++++++++++++++++++++++ 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index ac1b89f97a..0001afbdfe 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -1419,10 +1419,17 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const LTNode* op) { // Case 1. 0 <= xn < d divisible.CopyOnWrite()->DivideBy(gcd); return Rewriter::VisitExpr(divisible->Normalize() < make_zero(dtype)); - } else if (extra->args.size() == 1 && + } else if (extra->args.size() == 1 && extra->args[0]->scale == 1 && extra->args[0]->upper_factor != ConstIntBoundNode::kPosInf && extra->args[0]->upper_factor % (gcd * extra->args[0]->lower_factor) == 0) { - // Case 2. xn == yn % m, where m % d == 0 + // Case 2. xn == ((yn % m) // L), scale = +1, m % (d*L) == 0. + // S + xn < 0 with S divisible by d ⇔ S/d + xn // d < 0, because + // xn % d ∈ [0, d) lets us drop the remainder via the Case 1 argument, + // and xn // d = (yn // (d*L)) % (m/(d*L)). + // The scale must be +1: with scale = -1 the equivalence becomes ≤ + // rather than <, so the rewrite would strengthen the predicate and + // silently drop the boundary S/d == xn // d (e.g. row > col where + // row and col are independent projections of the same lane id). 
divisible.CopyOnWrite()->DivideBy(gcd); const auto split_expr = extra->args[0]; int64_t lower_factor = gcd * extra->args[0]->lower_factor; diff --git a/tests/python/arith/test_arith_canonical_simplify.py b/tests/python/arith/test_arith_canonical_simplify.py index ce89db9c99..49f480bcce 100644 --- a/tests/python/arith/test_arith_canonical_simplify.py +++ b/tests/python/arith/test_arith_canonical_simplify.py @@ -488,5 +488,49 @@ def test_simplify_le(): ck.verify(x * 1024 + y < z * 7168, x - z * 7 < 0) +def test_simplify_le_negative_scale_extra(): + """Regression: Case 2 of the LT-with-divisible-coeffs rewrite must not + fire when the leftover split term has a negative scale. + + The rewrite ``S + xn < 0 ⇔ S/d + xn // d < 0`` is only sound when + the leftover ``xn`` has scale ``+1``. With scale ``-1`` the equivalence + becomes ``≤`` rather than ``<`` and the rewrite silently strengthens + the predicate. The original bug surfaced as ``row > col`` masks of + ``.16x*b`` tcgen05 readbacks collapsing to plain ``warp_id > k`` + comparisons (lower-triangle writes were silently dropped on the + boundary warp). + """ + ck = CanonicalChecker() + tx = tvm.tirx.Var("tx", "int32") + warp = tvm.tirx.Var("warp", "int32") + ck.analyzer.bind(tx, tvm.ir.Range(0, 128)) + ck.analyzer.bind(warp, tvm.ir.Range(0, 4)) + + # Same-source joint projection: the comparison genuinely depends on tx + # at warp == 0 (e.g. tx == 4 ⇒ 0 < 1 = True; tx == 1 ⇒ 2 < 0 = False), + # so the simplifier must keep both sides. Pre-fix this folded to + # ``0 < warp`` and dropped every True case in warp 0. + expr = (tx % 4) * 2 < warp * 16 + (tx % 32) // 4 + ck.verify(expr, expr) + + # The simpler ``scale = -1`` with ``lower_factor = 1`` shape. Pre-fix + # this folded to ``False`` (drops all warp >= 1 cases where the rhs + # actually exceeds 8*warp). 
+ expr = warp * 8 < (tx % 32) + ck.verify(expr, expr) + + # The corresponding ``scale = +1`` Case 2 path (the rewrite this guards) + # must still optimize — verifies we did not over-restrict. + x1 = tvm.tirx.Var("x1", "int32") + y1 = tvm.tirx.Var("y1", "int32") + ck.verify(x1 * 64 + (y1 % 64) < 120, x1 * 8 + (y1 % 64) // 8 < 15) + + # The truly-always-true comparison that arises from the same kernel + # (``r = 2 / va = 1`` in the tcgen05.ld.16x256b readback) must still + # fold to True so the masked store can be elided. + expr_true = (tx % 4) * 2 < warp * 16 + (tx % 32) // 4 + 8 + ck.verify(expr_true, tvm.tirx.const(True, "bool")) + + if __name__ == "__main__": tvm.testing.main()
