Canonicalize non-equality comparisons of a right shift against a
constant by turning them into a comparison with a left-shifted
constant.  Assuming the generic format:

(A >> CST1) CMP CST2

For CMP (<, >=) we'll compare A with CST2 left shifted by CST1:

- (A >> CST1) < CST2  -> A < (CST2 << CST1)
- (A >> CST1) >= CST2 -> A >= (CST2 << CST1)

And for CMP (<=, >) we additionally need to IOR in the lower CST1 bits
(mask = (1 << CST1) - 1) of the left-shifted constant:

- (A >> CST1) <= CST2 -> A <= ((CST2 << CST1) | mask)
- (A >> CST1) > CST2  -> A > ((CST2 << CST1) | mask)

Given that the right-hand side changes involve just constants, in the
end we'll replace a rshift + cmp with just a cmp.

Bootstrapped and regression tested on x86, aarch64 and RISC-V.

        PR tree-optimization/124808

gcc/ChangeLog:

        * match.pd (`(A >> CST1) CMP CST2`): New pattern.

gcc/testsuite/ChangeLog:

        * gcc.dg/tree-ssa/pr124808-2.c: New test.
        * gcc.dg/tree-ssa/pr124808.c: New test.
---

Changes from v3:
- check manually for a potential overflow/sign bit set when doing
  the lshift;
- use tree_to_uhwi instead of TREE_INT_CST_LOW in 'mask_len';
- v3 link: https://gcc.gnu.org/pipermail/gcc-patches/2026-April/713064.html

 gcc/match.pd                               | 62 +++++++++++++++++
 gcc/testsuite/gcc.dg/tree-ssa/pr124808-2.c | 44 ++++++++++++
 gcc/testsuite/gcc.dg/tree-ssa/pr124808.c   | 78 ++++++++++++++++++++++
 3 files changed, 184 insertions(+)
 create mode 100644 gcc/testsuite/gcc.dg/tree-ssa/pr124808-2.c
 create mode 100644 gcc/testsuite/gcc.dg/tree-ssa/pr124808.c

diff --git a/gcc/match.pd b/gcc/match.pd
index 4ed058f6e18..8ce7bef5f90 100644
--- a/gcc/match.pd
+++ b/gcc/match.pd
@@ -4959,6 +4959,68 @@ DEFINE_INT_AND_FLOAT_ROUND_FN (RINT)
        && TYPE_UNSIGNED (TREE_TYPE (@1)))
    (cmp @0 @1))))
 
+/* PR124808: (A >> CST1) CMP CST2 -> A CMP (CST2 << CST1)
+   Canonicalize non-equality comparisons between a right
+   shift and a constant, turning it into a comparison
+   with a constant that is left shifted.  If we view A as:
+
+   A: "|--- CST2 ----|--CST1--|"
+
+   A >> CST1 will be equal to CST2 for all A values in the
+   range CST2 << CST1 to (CST2 << CST1) | 1s_mask (CST1).
+
+   Therefore:
+   - (A >> CST1) < CST2 -> A < (CST2 << CST1), A must
+     be smaller than all values from the range;
+   - (A >> CST1) <= CST2 -> A <= (CST2 << CST1) | mask,
+     A must be smaller than or equal to the range end;
+   - (A >> CST1) > CST2 -> A > (CST2 << CST1) | mask,
+     A must be greater than all values from the range;
+   - (A >> CST1) >= CST2 -> A >= (CST2 << CST1), A must
+     be greater than or equal to the range start.
+
+   We're also using "single_use" and wrapping around "if GIMPLE"
+   because (1) this class "VAL1 LSHIFT/RSHIFT VAL2 CMP VAL3" of
+   optimizations tends to match CTZ|CLZ builtin patterns and we
+   don't want to trip on them and (2) we would otherwise get in the
+   way of certain optimizations (see ARM's sat-1.c test) that are
+   done using GENERIC due to how forwprop currently works.  */
+#if GIMPLE
+(for cmp (le lt ge gt)
+ (simplify
+  (cmp (rshift@3 @0 INTEGER_CST@1) INTEGER_CST@2)
+  (if (INTEGRAL_TYPE_P (TREE_TYPE (@0))
+       && single_use (@3)
+       && tree_fits_uhwi_p (@1)
+       && tree_to_uhwi (@1) < TYPE_PRECISION (TREE_TYPE (@0))
+       && tree_int_cst_sgn (@2) >= 0
+       && tree_to_uhwi (@1) < TYPE_PRECISION (TREE_TYPE (@2))
+       && (wi::to_wide (@2) == 0
+          /* If @2 is nonzero, verify if we'll overflow when
+             doing @2 << @1 by checking if the shift amount
+             can start throwing 1s away if type_unsigned, or
+             if the shift amount can reach the sign bit.  */
+          || (TYPE_UNSIGNED (TREE_TYPE (@2))
+              && wi::leu_p (wi::to_wide (@1), wi::clz (wi::to_wide (@2))))
+          || (!TYPE_UNSIGNED (TREE_TYPE (@2))
+              && wi::ltu_p (wi::to_wide (@1), wi::clz (wi::to_wide (@2))))))
+
+   /* No need to set the lower @1 bits of the resulting
+      lshift for "<" and ">=" comparisons.  */
+   (if (cmp == LT_EXPR || cmp == GE_EXPR)
+    (cmp @0 (lshift @2 @1))
+
+     /* For "<=" and ">" set the lower @1 lshift bits.  */
+     (with {
+       tree type2 = TREE_TYPE (@2);
+       unsigned prec = TYPE_PRECISION (type2);
+       unsigned mask_len = tree_to_uhwi (@1);
+       wide_int cst1_mask = wi::mask (mask_len, false, prec);
+      }
+       (cmp @0 (bit_ior (lshift @2 @1)
+                       { wide_int_to_tree (type2, cst1_mask); })))))))
+#endif
+
 /* Rewrite an LROTATE_EXPR by a constant into an
    RROTATE_EXPR by a new constant.  */
 (simplify
diff --git a/gcc/testsuite/gcc.dg/tree-ssa/pr124808-2.c b/gcc/testsuite/gcc.dg/tree-ssa/pr124808-2.c
new file mode 100644
index 00000000000..bfab8109ce6
--- /dev/null
+++ b/gcc/testsuite/gcc.dg/tree-ssa/pr124808-2.c
@@ -0,0 +1,44 @@
+/* { dg-additional-options -O2 } */
+/* { dg-additional-options -fdump-tree-forwprop1 } */
+
+long* SetupPrecalculatedData1 (long* a) {
+  long b = 1;
+  int i;
+  for (i = 0; i < 64; i++) {
+    if(i>>3 < 7)
+      a[i] += (b<<(i+8));
+  }
+  return a;
+}
+
+long* SetupPrecalculatedData2 (long* a) {
+  long b = 1;
+  int i;
+  for (i = 0; i <= 64; i++) {
+    if(i>>3 < 7)
+      a[i] += (b<<(i+8));
+  }
+  return a;
+}
+
+long* SetupPrecalculatedData3 (long* a) {
+  long b = 1;
+  int i;
+  for (i = 0; i < 64; i++) {
+    if(i>>3 > 7)
+      a[i] += (b<<(i+8));
+  }
+  return a;
+}
+
+long* SetupPrecalculatedData4 (long* a) {
+  long b = 1;
+  int i;
+  for (i = 0; i < 64; i++) {
+    if(i>>3 >= 7)
+      a[i] += (b<<(i+8));
+  }
+  return a;
+}
+
+/* { dg-final { scan-tree-dump-times ">> 3" 0 forwprop1 } } */
diff --git a/gcc/testsuite/gcc.dg/tree-ssa/pr124808.c b/gcc/testsuite/gcc.dg/tree-ssa/pr124808.c
new file mode 100644
index 00000000000..5f3c1f0a8c4
--- /dev/null
+++ b/gcc/testsuite/gcc.dg/tree-ssa/pr124808.c
@@ -0,0 +1,78 @@
+/* { dg-do run } */
+/* { dg-options "-O2" } */
+
+void abort(void);
+
+/* Macro adapted from builtin-object-size-common.h  */
+#define FAIL() \
+  do { \
+    __builtin_printf ("Failure at line: %d\n", __LINE__);      \
+    abort();                                                   \
+  } while (0)
+
+#define SHIFTVAL 3
+#define CMPVAL 1
+
+long setValue1 (long in)
+{
+  if (in >> SHIFTVAL < CMPVAL)
+    return in += SHIFTVAL;
+  return -1;
+}
+
+long setValue2 (long in)
+{
+  if (in >> SHIFTVAL <= CMPVAL)
+    return in += SHIFTVAL;
+  return -1;
+}
+
+long setValue3 (long in)
+{
+  if (in >> SHIFTVAL > CMPVAL)
+    return in += SHIFTVAL;
+  return -1;
+}
+
+long setValue4 (long in)
+{
+  if (in >> SHIFTVAL >= CMPVAL)
+    return in += SHIFTVAL;
+  return -1;
+}
+
+int main (void) {
+  /* setValue1: (in >> 3) < 1, i.e. in < 8.  */
+  if (setValue1 (7) != 10)
+    FAIL ();
+  if (setValue1 (8) != -1)
+    FAIL ();
+
+  /* setValue2: (in >> 3) <= 1, i.e. in <= 15.  */
+  if (setValue2 (7) != 10)
+    FAIL ();
+  if (setValue2 (8) != 11)
+    FAIL ();
+  if (setValue2 (15) != 18)
+    FAIL ();
+  if (setValue2 (16) != -1)
+    FAIL ();
+
+  /* setValue3: (in >> 3) > 1, i.e. in > 15.  */
+  if (setValue3 (15) != -1)
+    FAIL ();
+  if (setValue3 (16) != 19)
+    FAIL ();
+
+  /* setValue4: (in >> 3) >= 1, i.e. in >= 8.  */
+  if (setValue4 (7) != -1)
+    FAIL ();
+  if (setValue4 (8) != 11)
+    FAIL ();
+  if (setValue4 (15) != 18)
+    FAIL ();
+  if (setValue4 (16) != 19)
+    FAIL ();
+
+  return 0;
+}
-- 
2.43.0

Reply via email to