https://github.com/python/cpython/commit/6d972e0104097e476118686ff7c84ea238ecafc3
commit: 6d972e0104097e476118686ff7c84ea238ecafc3
branch: main
author: reiden <[email protected]>
committer: Fidget-Spinner <[email protected]>
date: 2026-01-24T10:02:08Z
summary:

gh-130415: Narrow types to constants in branches involving specialized comparisons with a constant (GH-144150)

files:
M Include/internal/pycore_optimizer_types.h
M Lib/test/test_capi/test_opt.py
M Python/optimizer_analysis.c
M Python/optimizer_bytecodes.c
M Python/optimizer_cases.c.h
M Python/optimizer_symbols.c

diff --git a/Include/internal/pycore_optimizer_types.h 
b/Include/internal/pycore_optimizer_types.h
index a879ca26ce7b63..b4b93e8353812a 100644
--- a/Include/internal/pycore_optimizer_types.h
+++ b/Include/internal/pycore_optimizer_types.h
@@ -76,6 +76,8 @@ typedef struct {
 typedef enum {
     JIT_PRED_IS,
     JIT_PRED_IS_NOT,
+    JIT_PRED_EQ,
+    JIT_PRED_NE,
 } JitOptPredicateKind;
 
 typedef struct {
diff --git a/Lib/test/test_capi/test_opt.py b/Lib/test/test_capi/test_opt.py
index f224984777500a..ff3adb1d4542e0 100644
--- a/Lib/test/test_capi/test_opt.py
+++ b/Lib/test/test_capi/test_opt.py
@@ -890,6 +890,138 @@ def testfunc(n):
         self.assertLessEqual(len(guard_nos_unicode_count), 1)
         self.assertIn("_COMPARE_OP_STR", uops)
 
+    def test_compare_int_eq_narrows_to_constant(self):
+        def f(n):
+            def return_1():
+                return 1
+
+            hits = 0
+            v = return_1()
+            for _ in range(n):
+                if v == 1:
+                    if v == 1:
+                        hits += 1
+            return hits
+
+        res, ex = self._run_with_optimizer(f, TIER2_THRESHOLD)
+        self.assertEqual(res, TIER2_THRESHOLD)
+        self.assertIsNotNone(ex)
+        uops = get_opnames(ex)
+
+        # Constant narrowing allows constant folding for second comparison
+        self.assertLessEqual(count_ops(ex, "_COMPARE_OP_INT"), 1)
+
+    def test_compare_int_ne_narrows_to_constant(self):
+        def f(n):
+            def return_1():
+                return 1
+
+            hits = 0
+            v = return_1()
+            for _ in range(n):
+                if v != 1:
+                    hits += 1000
+                else:
+                    if v == 1:
+                        hits += v + 1
+            return hits
+
+        res, ex = self._run_with_optimizer(f, TIER2_THRESHOLD)
+        self.assertEqual(res, TIER2_THRESHOLD * 2)
+        self.assertIsNotNone(ex)
+        uops = get_opnames(ex)
+
+        # Constant narrowing allows constant folding for second comparison
+        self.assertLessEqual(count_ops(ex, "_COMPARE_OP_INT"), 1)
+
+    def test_compare_float_eq_narrows_to_constant(self):
+        def f(n):
+            def return_tenth():
+                return 0.1
+
+            hits = 0
+            v = return_tenth()
+            for _ in range(n):
+                if v == 0.1:
+                    if v == 0.1:
+                        hits += 1
+            return hits
+
+        res, ex = self._run_with_optimizer(f, TIER2_THRESHOLD)
+        self.assertEqual(res, TIER2_THRESHOLD)
+        self.assertIsNotNone(ex)
+        uops = get_opnames(ex)
+
+        # Constant narrowing allows constant folding for second comparison
+        self.assertLessEqual(count_ops(ex, "_COMPARE_OP_FLOAT"), 1)
+
+    def test_compare_float_ne_narrows_to_constant(self):
+        def f(n):
+            def return_tenth():
+                return 0.1
+
+            hits = 0
+            v = return_tenth()
+            for _ in range(n):
+                if v != 0.1:
+                    hits += 1000
+                else:
+                    if v == 0.1:
+                        hits += 1
+            return hits
+
+        res, ex = self._run_with_optimizer(f, TIER2_THRESHOLD)
+        self.assertEqual(res, TIER2_THRESHOLD)
+        self.assertIsNotNone(ex)
+        uops = get_opnames(ex)
+
+        # Constant narrowing allows constant folding for second comparison
+        self.assertLessEqual(count_ops(ex, "_COMPARE_OP_FLOAT"), 1)
+
+    def test_compare_str_eq_narrows_to_constant(self):
+        def f(n):
+            def return_hello():
+                return "hello"
+
+            hits = 0
+            v = return_hello()
+            for _ in range(n):
+                if v == "hello":
+                    if v == "hello":
+                        hits += 1
+            return hits
+
+        res, ex = self._run_with_optimizer(f, TIER2_THRESHOLD)
+        self.assertEqual(res, TIER2_THRESHOLD)
+        self.assertIsNotNone(ex)
+        uops = get_opnames(ex)
+
+        # Constant narrowing allows constant folding for second comparison
+        self.assertLessEqual(count_ops(ex, "_COMPARE_OP_STR"), 1)
+
+    def test_compare_str_ne_narrows_to_constant(self):
+        def f(n):
+            def return_hello():
+                return "hello"
+
+            hits = 0
+            v = return_hello()
+            for _ in range(n):
+                if v != "hello":
+                    hits += 1000
+                else:
+                    if v == "hello":
+                        hits += 1
+            return hits
+
+        res, ex = self._run_with_optimizer(f, TIER2_THRESHOLD)
+        self.assertEqual(res, TIER2_THRESHOLD)
+        self.assertIsNotNone(ex)
+        uops = get_opnames(ex)
+
+        # Constant narrowing allows constant folding for second comparison
+        self.assertLessEqual(count_ops(ex, "_COMPARE_OP_STR"), 1)
+
     @unittest.skip("gh-139109 WIP")
     def test_combine_stack_space_checks_sequential(self):
         def dummy12(x):
diff --git a/Python/optimizer_analysis.c b/Python/optimizer_analysis.c
index 6c381ab184fd85..65c9239f1ff427 100644
--- a/Python/optimizer_analysis.c
+++ b/Python/optimizer_analysis.c
@@ -250,6 +250,11 @@ add_op(JitOptContext *ctx, _PyUOpInstruction *this_instr,
 #define sym_new_predicate _Py_uop_sym_new_predicate
 #define sym_apply_predicate_narrowing _Py_uop_sym_apply_predicate_narrowing
 
+/* Comparison oparg masks */
+#define COMPARE_LT_MASK 2
+#define COMPARE_GT_MASK 4
+#define COMPARE_EQ_MASK 8
+
 #define JUMP_TO_LABEL(label) goto label;
 
 static int
diff --git a/Python/optimizer_bytecodes.c b/Python/optimizer_bytecodes.c
index 27b974f372df99..38cd088d9fb030 100644
--- a/Python/optimizer_bytecodes.c
+++ b/Python/optimizer_bytecodes.c
@@ -521,21 +521,51 @@ dummy_func(void) {
     }
 
     op(_COMPARE_OP_INT, (left, right -- res, l, r)) {
-        res = sym_new_type(ctx, &PyBool_Type);
+        int cmp_mask = oparg & (COMPARE_LT_MASK | COMPARE_GT_MASK | 
COMPARE_EQ_MASK);
+
+        if (cmp_mask == COMPARE_EQ_MASK) {
+            res = sym_new_predicate(ctx, left, right, JIT_PRED_EQ);
+        }
+        else if (cmp_mask == (COMPARE_LT_MASK | COMPARE_GT_MASK)) {
+            res = sym_new_predicate(ctx, left, right, JIT_PRED_NE);
+        }
+        else {
+            res = sym_new_type(ctx, &PyBool_Type);
+        }
         l = left;
         r = right;
         REPLACE_OPCODE_IF_EVALUATES_PURE(left, right, res);
     }
 
     op(_COMPARE_OP_FLOAT, (left, right -- res, l, r)) {
-        res = sym_new_type(ctx, &PyBool_Type);
+        int cmp_mask = oparg & (COMPARE_LT_MASK | COMPARE_GT_MASK | 
COMPARE_EQ_MASK);
+
+        if (cmp_mask == COMPARE_EQ_MASK) {
+            res = sym_new_predicate(ctx, left, right, JIT_PRED_EQ);
+        }
+        else if (cmp_mask == (COMPARE_LT_MASK | COMPARE_GT_MASK)) {
+            res = sym_new_predicate(ctx, left, right, JIT_PRED_NE);
+        }
+        else {
+            res = sym_new_type(ctx, &PyBool_Type);
+        }
         l = left;
         r = right;
         REPLACE_OPCODE_IF_EVALUATES_PURE(left, right, res);
     }
 
     op(_COMPARE_OP_STR, (left, right -- res, l, r)) {
-        res = sym_new_type(ctx, &PyBool_Type);
+        int cmp_mask = oparg & (COMPARE_LT_MASK | COMPARE_GT_MASK | 
COMPARE_EQ_MASK);
+
+        if (cmp_mask == COMPARE_EQ_MASK) {
+            res = sym_new_predicate(ctx, left, right, JIT_PRED_EQ);
+        }
+        else if (cmp_mask == (COMPARE_LT_MASK | COMPARE_GT_MASK)) {
+            res = sym_new_predicate(ctx, left, right, JIT_PRED_NE);
+        }
+        else {
+            res = sym_new_type(ctx, &PyBool_Type);
+        }
         l = left;
         r = right;
         REPLACE_OPCODE_IF_EVALUATES_PURE(left, right, res);
diff --git a/Python/optimizer_cases.c.h b/Python/optimizer_cases.c.h
index 3b25533e07f743..e9405473fe2e0d 100644
--- a/Python/optimizer_cases.c.h
+++ b/Python/optimizer_cases.c.h
@@ -2118,7 +2118,16 @@
             JitOptRef r;
             right = stack_pointer[-1];
             left = stack_pointer[-2];
-            res = sym_new_type(ctx, &PyBool_Type);
+            int cmp_mask = oparg & (COMPARE_LT_MASK | COMPARE_GT_MASK | 
COMPARE_EQ_MASK);
+            if (cmp_mask == COMPARE_EQ_MASK) {
+                res = sym_new_predicate(ctx, left, right, JIT_PRED_EQ);
+            }
+            else if (cmp_mask == (COMPARE_LT_MASK | COMPARE_GT_MASK)) {
+                res = sym_new_predicate(ctx, left, right, JIT_PRED_NE);
+            }
+            else {
+                res = sym_new_type(ctx, &PyBool_Type);
+            }
             l = left;
             r = right;
             if (
@@ -2178,7 +2187,16 @@
             JitOptRef r;
             right = stack_pointer[-1];
             left = stack_pointer[-2];
-            res = sym_new_type(ctx, &PyBool_Type);
+            int cmp_mask = oparg & (COMPARE_LT_MASK | COMPARE_GT_MASK | 
COMPARE_EQ_MASK);
+            if (cmp_mask == COMPARE_EQ_MASK) {
+                res = sym_new_predicate(ctx, left, right, JIT_PRED_EQ);
+            }
+            else if (cmp_mask == (COMPARE_LT_MASK | COMPARE_GT_MASK)) {
+                res = sym_new_predicate(ctx, left, right, JIT_PRED_NE);
+            }
+            else {
+                res = sym_new_type(ctx, &PyBool_Type);
+            }
             l = left;
             r = right;
             if (
@@ -2242,7 +2260,16 @@
             JitOptRef r;
             right = stack_pointer[-1];
             left = stack_pointer[-2];
-            res = sym_new_type(ctx, &PyBool_Type);
+            int cmp_mask = oparg & (COMPARE_LT_MASK | COMPARE_GT_MASK | 
COMPARE_EQ_MASK);
+            if (cmp_mask == COMPARE_EQ_MASK) {
+                res = sym_new_predicate(ctx, left, right, JIT_PRED_EQ);
+            }
+            else if (cmp_mask == (COMPARE_LT_MASK | COMPARE_GT_MASK)) {
+                res = sym_new_predicate(ctx, left, right, JIT_PRED_NE);
+            }
+            else {
+                res = sym_new_type(ctx, &PyBool_Type);
+            }
             l = left;
             r = right;
             if (
diff --git a/Python/optimizer_symbols.c b/Python/optimizer_symbols.c
index a9640aaa5072c5..51cf6e189f0f49 100644
--- a/Python/optimizer_symbols.c
+++ b/Python/optimizer_symbols.c
@@ -875,9 +875,11 @@ _Py_uop_sym_apply_predicate_narrowing(JitOptContext *ctx, 
JitOptRef ref, bool br
 
     bool narrow = false;
     switch(pred.kind) {
+        case JIT_PRED_EQ:
         case JIT_PRED_IS:
             narrow = branch_is_true;
             break;
+        case JIT_PRED_NE:
         case JIT_PRED_IS_NOT:
             narrow = !branch_is_true;
             break;
@@ -1300,11 +1302,11 @@ _Py_uop_symbols_test(PyObject *Py_UNUSED(self), 
PyObject *Py_UNUSED(ignored))
     TEST_PREDICATE(_Py_uop_sym_is_const(ctx, subject), "predicate narrowing 
did not const-narrow subject (None)");
     TEST_PREDICATE(_Py_uop_sym_get_const(ctx, subject) == Py_None, "predicate 
narrowing did not narrow subject to None");
 
-    // Test narrowing subject to numerical constant
+    // Test narrowing subject to numerical constant from is comparison
     subject = _Py_uop_sym_new_unknown(ctx);
     PyObject *one_obj = PyLong_FromLong(1);
     JitOptRef const_one = _Py_uop_sym_new_const(ctx, one_obj);
-    if (PyJitRef_IsNull(subject) || PyJitRef_IsNull(const_one)) {
+    if (PyJitRef_IsNull(subject) || one_obj == NULL || 
PyJitRef_IsNull(const_one)) {
         goto fail;
     }
     ref = _Py_uop_sym_new_predicate(ctx, subject, const_one, JIT_PRED_IS);
@@ -1315,6 +1317,160 @@ _Py_uop_symbols_test(PyObject *Py_UNUSED(self), 
PyObject *Py_UNUSED(ignored))
     TEST_PREDICATE(_Py_uop_sym_is_const(ctx, subject), "predicate narrowing 
did not const-narrow subject (1)");
     TEST_PREDICATE(_Py_uop_sym_get_const(ctx, subject) == one_obj, "predicate 
narrowing did not narrow subject to 1");
 
+    // Test narrowing subject to constant from EQ predicate for int
+    subject = _Py_uop_sym_new_unknown(ctx);
+    if (PyJitRef_IsNull(subject)) {
+        goto fail;
+    }
+    ref = _Py_uop_sym_new_predicate(ctx, subject, const_one, JIT_PRED_EQ);
+    if (PyJitRef_IsNull(ref)) {
+        goto fail;
+    }
+    _Py_uop_sym_apply_predicate_narrowing(ctx, ref, true);
+    TEST_PREDICATE(_Py_uop_sym_is_const(ctx, subject), "predicate narrowing 
did not const-narrow subject (1)");
+    TEST_PREDICATE(_Py_uop_sym_get_const(ctx, subject) == one_obj, "predicate 
narrowing did not narrow subject to 1");
+
+    // Resolving EQ predicate to False should not narrow subject for int
+    subject = _Py_uop_sym_new_unknown(ctx);
+    if (PyJitRef_IsNull(subject)) {
+        goto fail;
+    }
+    ref = _Py_uop_sym_new_predicate(ctx, subject, const_one, JIT_PRED_EQ);
+    if (PyJitRef_IsNull(ref)) {
+        goto fail;
+    }
+    _Py_uop_sym_apply_predicate_narrowing(ctx, ref, false);
+    TEST_PREDICATE(!_Py_uop_sym_is_const(ctx, subject), "predicate narrowing 
incorrectly narrowed subject (inverted/true)");
+
+    // Test narrowing subject to constant from NE predicate for int
+    subject = _Py_uop_sym_new_unknown(ctx);
+    if (PyJitRef_IsNull(subject)) {
+        goto fail;
+    }
+    ref = _Py_uop_sym_new_predicate(ctx, subject, const_one, JIT_PRED_NE);
+    if (PyJitRef_IsNull(ref)) {
+        goto fail;
+    }
+    _Py_uop_sym_apply_predicate_narrowing(ctx, ref, false);
+    TEST_PREDICATE(_Py_uop_sym_is_const(ctx, subject), "predicate narrowing 
did not const-narrow subject (1)");
+    TEST_PREDICATE(_Py_uop_sym_get_const(ctx, subject) == one_obj, "predicate 
narrowing did not narrow subject to 1");
+
+    // Resolving NE predicate to true should not narrow subject for int
+    subject = _Py_uop_sym_new_unknown(ctx);
+    if (PyJitRef_IsNull(subject)) {
+        goto fail;
+    }
+    ref = _Py_uop_sym_new_predicate(ctx, subject, const_one, JIT_PRED_NE);
+    if (PyJitRef_IsNull(ref)) {
+        goto fail;
+    }
+    _Py_uop_sym_apply_predicate_narrowing(ctx, ref, true);
+    TEST_PREDICATE(!_Py_uop_sym_is_const(ctx, subject), "predicate narrowing 
incorrectly narrowed subject (inverted/true)");
+
+    // Test narrowing subject to constant from EQ predicate for float
+    subject = _Py_uop_sym_new_unknown(ctx);
+    PyObject *float_tenth_obj = PyFloat_FromDouble(0.1);
+    JitOptRef const_float_tenth = _Py_uop_sym_new_const(ctx, float_tenth_obj);
+    if (PyJitRef_IsNull(subject) || float_tenth_obj == NULL || 
PyJitRef_IsNull(const_float_tenth)) {
+        goto fail;
+    }
+    ref = _Py_uop_sym_new_predicate(ctx, subject, const_float_tenth, 
JIT_PRED_EQ);
+    if (PyJitRef_IsNull(ref)) {
+        goto fail;
+    }
+    _Py_uop_sym_apply_predicate_narrowing(ctx, ref, true);
+    TEST_PREDICATE(_Py_uop_sym_is_const(ctx, subject), "predicate narrowing 
did not const-narrow subject (float)");
+    TEST_PREDICATE(_Py_uop_sym_get_const(ctx, subject) == float_tenth_obj, 
"predicate narrowing did not narrow subject to 0.1");
+
+    // Resolving EQ predicate to False should not narrow subject for float
+    subject = _Py_uop_sym_new_unknown(ctx);
+    if (PyJitRef_IsNull(subject)) {
+        goto fail;
+    }
+    ref = _Py_uop_sym_new_predicate(ctx, subject, const_float_tenth, 
JIT_PRED_EQ);
+    if (PyJitRef_IsNull(ref)) {
+        goto fail;
+    }
+    _Py_uop_sym_apply_predicate_narrowing(ctx, ref, false);
+    TEST_PREDICATE(!_Py_uop_sym_is_const(ctx, subject), "predicate narrowing 
incorrectly narrowed subject (inverted/true)");
+
+    // Test narrowing subject to constant from NE predicate for float
+    subject = _Py_uop_sym_new_unknown(ctx);
+    if (PyJitRef_IsNull(subject)) {
+        goto fail;
+    }
+    ref = _Py_uop_sym_new_predicate(ctx, subject, const_float_tenth, 
JIT_PRED_NE);
+    if (PyJitRef_IsNull(ref)) {
+        goto fail;
+    }
+    _Py_uop_sym_apply_predicate_narrowing(ctx, ref, false);
+    TEST_PREDICATE(_Py_uop_sym_is_const(ctx, subject), "predicate narrowing 
did not const-narrow subject (float)");
+    TEST_PREDICATE(_Py_uop_sym_get_const(ctx, subject) == float_tenth_obj, 
"predicate narrowing did not narrow subject to 0.1");
+
+    // Resolving NE predicate to true should not narrow subject for float
+    subject = _Py_uop_sym_new_unknown(ctx);
+    if (PyJitRef_IsNull(subject)) {
+        goto fail;
+    }
+    ref = _Py_uop_sym_new_predicate(ctx, subject, const_float_tenth, 
JIT_PRED_NE);
+    if (PyJitRef_IsNull(ref)) {
+        goto fail;
+    }
+    _Py_uop_sym_apply_predicate_narrowing(ctx, ref, true);
+    TEST_PREDICATE(!_Py_uop_sym_is_const(ctx, subject), "predicate narrowing 
incorrectly narrowed subject (inverted/true)");
+
+    // Test narrowing subject to constant from EQ predicate for str
+    subject = _Py_uop_sym_new_unknown(ctx);
+    PyObject *str_hello_obj = PyUnicode_FromString("hello");
+    JitOptRef const_str_hello = _Py_uop_sym_new_const(ctx, str_hello_obj);
+    if (PyJitRef_IsNull(subject) || str_hello_obj == NULL || 
PyJitRef_IsNull(const_str_hello)) {
+        goto fail;
+    }
+    ref = _Py_uop_sym_new_predicate(ctx, subject, const_str_hello, 
JIT_PRED_EQ);
+    if (PyJitRef_IsNull(ref)) {
+        goto fail;
+    }
+    _Py_uop_sym_apply_predicate_narrowing(ctx, ref, true);
+    TEST_PREDICATE(_Py_uop_sym_is_const(ctx, subject), "predicate narrowing 
did not const-narrow subject (str)");
+    TEST_PREDICATE(_Py_uop_sym_get_const(ctx, subject) == str_hello_obj, 
"predicate narrowing did not narrow subject to hello");
+
+    // Resolving EQ predicate to False should not narrow subject for str
+    subject = _Py_uop_sym_new_unknown(ctx);
+    if (PyJitRef_IsNull(subject)) {
+        goto fail;
+    }
+    ref = _Py_uop_sym_new_predicate(ctx, subject, const_str_hello, 
JIT_PRED_EQ);
+    if (PyJitRef_IsNull(ref)) {
+        goto fail;
+    }
+    _Py_uop_sym_apply_predicate_narrowing(ctx, ref, false);
+    TEST_PREDICATE(!_Py_uop_sym_is_const(ctx, subject), "predicate narrowing 
incorrectly narrowed subject (inverted/true)");
+
+    // Test narrowing subject to constant from NE predicate for str
+    subject = _Py_uop_sym_new_unknown(ctx);
+    if (PyJitRef_IsNull(subject)) {
+        goto fail;
+    }
+    ref = _Py_uop_sym_new_predicate(ctx, subject, const_str_hello, 
JIT_PRED_NE);
+    if (PyJitRef_IsNull(ref)) {
+        goto fail;
+    }
+    _Py_uop_sym_apply_predicate_narrowing(ctx, ref, false);
+    TEST_PREDICATE(_Py_uop_sym_is_const(ctx, subject), "predicate narrowing 
did not const-narrow subject (str)");
+    TEST_PREDICATE(_Py_uop_sym_get_const(ctx, subject) == str_hello_obj, 
"predicate narrowing did not narrow subject to hello");
+
+    // Resolving NE predicate to true should not narrow subject for str
+    subject = _Py_uop_sym_new_unknown(ctx);
+    if (PyJitRef_IsNull(subject)) {
+        goto fail;
+    }
+    ref = _Py_uop_sym_new_predicate(ctx, subject, const_str_hello, 
JIT_PRED_NE);
+    if (PyJitRef_IsNull(ref)) {
+        goto fail;
+    }
+    _Py_uop_sym_apply_predicate_narrowing(ctx, ref, true);
+    TEST_PREDICATE(!_Py_uop_sym_is_const(ctx, subject), "predicate narrowing 
incorrectly narrowed subject (inverted/true)");
+
     val_big = PyNumber_Lshift(_PyLong_GetOne(), PyLong_FromLong(66));
     if (val_big == NULL) {
         goto fail;

_______________________________________________
Python-checkins mailing list -- [email protected]
To unsubscribe send an email to [email protected]
https://mail.python.org/mailman3//lists/python-checkins.python.org
Member address: [email protected]

Reply via email to